#include "bits/stdc++.h"
// @JASPER'S BOILERPLATE
using namespace std;
using ll = long long;
#ifdef JASPER
#include "debug.h"
#else
#define debug(...) 166
#endif
// ---- Problem constants ----
using ull = unsigned long long;
constexpr int N = 50005;          // capacity for vertex-indexed arrays
constexpr int B = 31;             // polynomial hash base
constexpr int MOD = 1000000007;   // hash modulus

// ---- Input ----
int n;                  // number of vertices (1-indexed)
string s;               // vertex letters; main prepends a sentinel so s[i] is vertex i
vector<int> adj[N];     // adjacency lists of the tree
int c[N];               // letter value of each vertex ('a' -> 1), filled in main
int pw[N];              // powers of B modulo MOD, filled in main

// ---- Centroid decomposition state ----
int sz[N];              // subtree sizes inside the component currently being decomposed
int dep[N];             // depth of a vertex below the current centroid (centroid = 0)
bool vis[N];            // set once a vertex has been taken as a centroid
int Tsize;              // read by getCentroid; expected to hold the current component size
// Recomputes sz[] for the component containing u, treating u as the root.
// The parent p and every vertex already removed as a centroid (vis[] set)
// are not descended into.
void reSubsize(int u, int p) {
    int total = 1;
    for (int child : adj[u]) {
        if (child == p || vis[child]) continue;
        reSubsize(child, u);
        total += sz[child];
    }
    sz[u] = total;
}
// Starting at u (whose parent in the current rooting is p), repeatedly steps
// into the first non-removed child whose subtree holds more than Tsize / 2
// vertices. Returns the vertex where no such child exists. Tsize is read from
// the global and is expected to be the size of the component rooted at u.
int getCentroid(int u, int p) {
    for (;;) {
        int heavy = -1;
        for (int v : adj[u]) {
            if (v != p && !vis[v] && sz[v] > Tsize / 2) {
                heavy = v;
                break;
            }
        }
        if (heavy == -1) return u;
        p = u;
        u = heavy;
    }
}
// ---- Per-centroid working state (used by add / solve) ----
int maxDep;             // deepest depth visited below the current centroid
int hashU[N];           // path (centroid's child .. u) hashed top-down with base B; centroid excluded
int hashD[N];           // sum of letter * pw[depth] over the path centroid .. u; centroid included
int value[N];           // key under which a visited node is stored in mp
int nNode;              // number of entries currently in node[]
int node[N];            // node[1..nNode]: nodes gathered from the child subtree being scanned
map<int, bool> mp[N];   // mp[d]: keys of nodes at depth d from earlier child subtrees
int k;                  // path length (number of vertices) currently being tested
// Scans the subtree entered through the edge (p, u) below the current
// centroid. It looks for a path of exactly k vertices through the centroid
// that reads the same in both directions (compared via hashes).
//
// For a node at depth d (centroid at depth 0):
//   hashD = sum over the path centroid..node of letter * pw[depth]
//   hashU = the same path without the centroid, hashed top-down with base B
// A path w .. centroid .. u with dep[w] + 1 + dep[u] == k is a palindrome iff
//   hashU[w] + pw[dep[w]] * hashD[u] == hashU[u] + pw[dep[u]] * hashD[w].
// The u-only part ("need") is looked up among the w-only parts (value[w])
// stored in mp[k - dep[u] - 1] for nodes of earlier child subtrees.
//
// Sets `found` and returns as soon as a match is seen. Every visited node
// that did not match is appended to node[1..nNode] with its key in value[].
// The caller registers these in mp once the whole child subtree is done.
void add(int u, int p, bool &found) {
    dep[u] = dep[p] + 1;
    if (dep[u] + 1 > k) return;  // any path through u would exceed k vertices
    maxDep = max(maxDep, dep[u]);
    hashU[u] = (1LL * hashU[p] * B + (s[u] - 'a' + 1)) % MOD;
    hashD[u] = (hashD[p] + 1LL * pw[dep[u]] * (s[u] - 'a' + 1)) % MOD;
    int need = (1LL * pw[k - dep[u] - 1] * hashD[u] - hashU[u] + MOD) % MOD;
    if (mp[k - dep[u] - 1].find(need) != mp[k - dep[u] - 1].end()) {
        found = true;
        return;
    }
    // The path centroid .. u alone has k vertices: compare its forward hash
    // with its reversed hash. Using u's letter as the top coefficient is
    // equivalent to using the centroid's: the lowest coefficient already
    // forces the two letters to be equal.
    // The modulo must be applied BEFORE narrowing to int. The sum can reach
    // ~2.8e10, and `(int) (X) % MOD` would truncate X first.
    if (dep[u] + 1 == k &&
        hashD[u] == static_cast<int>((hashU[u] + 1LL * pw[dep[u]] * (s[u] - 'a' + 1)) % MOD)) {
        found = true;
        return;
    }
    node[++nNode] = u;
    value[u] = need;
    for (int v : adj[u]) {
        if (v != p && !vis[v]) {
            add(v, u, found);
            if (found) return;
        }
    }
}
void solve(int u, int p, bool &found) {
reSubsize(u, p);
int x = getCentroid(u, p);
vis[x] = true;
dep[x] = 0;
nNode = 0;
maxDep = 0;
hashU[x] = 0;
hashD[x] = s[x] - 'a' + 1;
for (int v: adj[x]) {
if (v != p && !vis[v]) {
add(v, x, found);
if (found) return;
while (nNode > 0) {
int y = node[nNode--];
mp[dep[y]][value[y]] = 1;
}
}
}
for (int i = 0; i <= maxDep; ++i)
mp[i].clear();
for (int v: adj[x]) {
if (v != p && !vis[v]) {
solve(v, x, found);
if (found) return;
}
}
}
bool func(int x) {
memset(vis, false, sizeof vis);
bool found = false;
k = x;
solve(1, 0, found);
return found;
}
// Reads the tree and its letters, then prints the length (in vertices) of the
// longest palindromic path.
signed main() {
    cin.tie(0) -> sync_with_stdio(0);
    cin >> n >> s;
    s = "@" + s;  // sentinel so that s[i] is the letter of vertex i (1-indexed)
    for (int i = 1; i < n; ++i) {
        int u, v;
        cin >> u >> v;
        adj[u].push_back(v);
        adj[v].push_back(u);
    }
    for (int i = 1; i <= n; ++i)
        c[i] = s[i] - 'a' + 1;
    pw[0] = 1;
    for (int i = 1; i <= n; ++i) {
        // 64-bit product: pw[i - 1] can be close to MOD, and pw[i - 1] * B
        // overflows int. That is UB and makes pw inconsistent with hashU.
        pw[i] = 1LL * pw[i - 1] * B % MOD;
    }
    // Removing both ends of a palindromic path leaves a palindromic path. So
    // for each parity i, the feasible lengths i + 2 * mid form a prefix.
    // Binary search for the largest one.
    int ans = 1;
    for (int i = 0; i <= 1; ++i) {
        int l = 0, r = n / 2 + 1;
        while (l + 1 < r) {
            int mid = (l + r) / 2;
            if (func(i + 2 * mid))
                l = mid;
            else
                r = mid;
        }
        ans = max(ans, i + 2 * l);
    }
    cout << ans << "\n";
    return 0;
}
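// Approach: centroid decomposition + hashing. A path is a palindrome when its forward and backward hashes are equal; the length is binary-searched separately for odd and even lengths.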
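/*
 * | # | Result    | Execution time | Memory  | Grader output        |
 * |---|-----------|----------------|---------|----------------------|
 * | 1 | Incorrect | 30 ms          | 5456 KB | Output isn't correct |
 * | 2 | Halted    | 0 ms           | 0 KB    | -                    |
 */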
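/*
 * | # | Result              | Execution time | Memory   | Grader output       |
 * |---|---------------------|----------------|----------|---------------------|
 * | 1 | Execution timed out | 5060 ms        | 12040 KB | Time limit exceeded |
 * | 2 | Halted              | 0 ms           | 0 KB     | -                   |
 */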
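/*
 * | # | Result              | Execution time | Memory   | Grader output       |
 * |---|---------------------|----------------|----------|---------------------|
 * | 1 | Execution timed out | 5039 ms        | 11336 KB | Time limit exceeded |
 * | 2 | Halted              | 0 ms           | 0 KB     | -                   |
 */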
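/*
 * | # | Result    | Execution time | Memory  | Grader output        |
 * |---|-----------|----------------|---------|----------------------|
 * | 1 | Incorrect | 30 ms          | 5456 KB | Output isn't correct |
 * | 2 | Halted    | 0 ms           | 0 KB    | -                    |
 */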