#include <bits/stdc++.h>
using namespace std;
// Capacity of every per-node / per-depth array below.
const int maxn = 200005;
// Multiplier and modulus of the polynomial hashes built in dfs().
const long long base = 35711;
const long long mod  = 1e9 + 7;
// n       - number of nodes, read in solve()
// Len     - path length currently being tested, set by check()
// maxDep  - deepest level reached by dfs(); CD() uses it to know which f[] to clear
// child[] - subtree sizes filled in by countChild()
// removed[] - non-zero once a node has been used as a decomposition root in CD()
int n, Len, maxDep, child[maxn], removed[maxn];
// a[i] is the character attached to node i (1-based, read in solve()).
char a[maxn];
// (depth, key) pairs collected by dfs() while walking one child subtree of the root.
vector<pair<int, int>> b;
// pw[i] = base^i mod `mod`, filled for i in [0, n] by solve().
long long pw[maxn];
// Adjacency lists of the input tree.
vector<int> adj[maxn];
// f[d] holds the keys already recorded at depth d during the current dfs() walk.
map<int, bool> f[maxn];
// Fills child[] with subtree sizes for the part of the tree reachable from `u`
// without crossing `p` or any node flagged in removed[].
//   u - node whose subtree size is being computed
//   p - node we came from (0 when `u` is the starting node)
void countChild(int u, int p){
    child[u] = 1;
    const vector<int>& neighbours = adj[u];
    for (size_t idx = 0; idx < neighbours.size(); ++idx) {
        const int next = neighbours[idx];
        // Never walk back to the parent or into an already removed node.
        if (next == p) continue;
        if (removed[next]) continue;
        countChild(next, u);
        child[u] = child[u] + child[next];
    }
}
// Walks the not-yet-removed component below the current decomposition root and
// reports whether some visited node's key is already present in f at the
// complementary depth Len - h + 1.
//
//   u       - node being visited
//   p       - node we arrived from; 0 marks the root of the walk (CD() passes 0)
//   h       - 1-based depth of u in this walk (CD() starts at 1)
//   hshdown - hash carried down from the parent; the root's own character is
//             never folded in because the update is skipped when p == 0
//   hshup   - running sum of a[node] * pw[depth - 1] over the nodes above u
//
// Returns true on the first match, false if the walk ends without one.
// Side effects: inserts keys into f[], appends to b, raises maxDep.
//
// NOTE(review): a match pairs depths h and Len - h + 1 that share the root, i.e.
// a path of Len vertices through the root; presumably equal keys mean that path
// reads as a palindrome - confirm against the problem statement. Matching is by
// hash value only, so collisions are possible.
bool dfs(int u, int p, int h, long long hshdown, long long hshup)
{
    // Nodes deeper than Len are not explored.
    if (h > Len) return false;
    // Extend the downward hash for every node except the root of the walk.
    if (p)
        hshdown = (hshdown * base + a[u]) % mod;
    hshup = (hshup + 1LL * a[u] * pw[h - 1]) % mod;
    // Key for this node: hshup scaled by pw[Len - h], minus hshdown.
    // Adding mod keeps the value non-negative before the final reduction.
    long long x =  (hshup * pw[Len - h] - hshdown + mod) % mod;
    // The root registers its own key immediately.
    if (!p) f[h][x] = true;
    // Look for the same key among nodes stored at the complementary depth.
    if (f[Len - h + 1].find(x) != f[Len - h + 1].end() )
        return true;
    for(int v: adj[u]) if(v != p && !removed[v]){
        // At the root, start a fresh collection for each child subtree.
        if(!p) b.clear();
        if(dfs(v, u, h + 1, hshdown, hshup)) return true;
        // Publish a child subtree's keys only after it has been fully walked,
        // so its nodes are matched against the root and earlier subtrees only.
        if(!p)
            for(auto [height, val]: b) f[height][val] = 1;
    }
    maxDep = max(maxDep, h);
    // Record this node's (depth, key) for publication by the root.
    b.push_back({h, x});
    return false;
}
bool CD(int u, int n){
    countChild(u, 0);
    int flag = 1, half = n / 2;
    while(flag){
        flag = 0;
        for (int v : adj[u])
            if (!removed[v] && child[v] < child[u] && child[v] > half){
                u = v;
                flag = 1;
                break;
            }
    }
    if(dfs(u, 0, 1, 0, 0)) return true;
    for(int i = 1; i <= maxDep; i++) f[i].clear();
    maxDep = 0;
    removed[u] = true;
    for(int v: adj[u]){
        if(!removed[v]){
            if(CD(v, child[v])) return true;
        }
    }
    return false;
}
// Tests one candidate length: resets the per-node state and runs the
// decomposition from node 1 with Len = len.
//
//   len - number of vertices on the path being looked for
//
// Returns the result of CD(1, n); always false for len outside [1, n].
// Lengths outside that range are rejected before any state is touched: dfs()
// indexes pw[Len - h] and f[Len - h + 1], arrays of maxn entries, so an
// oversized len could index past them. A match pairs two depths of one
// component (at most n nodes), so no length above n can match anyway.
bool check(int len){
    if (len < 1 || len > n) return false;
    Len = len;
    for (int i = 1; i <= n; i++) {
        removed[i] = 0;
        f[i].clear();
    }
    return CD(1, n);
}
// Reads the tree, precomputes the powers of `base`, then binary-searches the
// largest odd and the largest even length accepted by check() and prints the
// larger of the two (0 if none is accepted).
//
// Input format as read below: n, then n characters (one per node, 1-based),
// then n - 1 edges "u v".
// NOTE(review): assumes n < maxn and that the edges form a tree - confirm
// against the problem constraints.
// NOTE(review): searching each parity separately presumes check() is monotone
// within a parity (a palindromic path of L vertices contains one of L - 2).
void solve(){
    cin >> n;
    // Empty or unreadable input: there is no node, hence no path.
    if (n <= 0) {
        cout << 0;
        return;
    }
    for (int i = 1; i <= n; i++) cin >> a[i];
    for (int i = 1; i <= n - 1; i++) {
        int u, v;
        cin >> u >> v;
        adj[u].push_back(v);
        adj[v].push_back(u);
    }
    pw[0] = 1;
    for (int i = 1; i <= n; i++) pw[i] = pw[i - 1] * base % mod;

    int ans = 0;

    // Odd lengths 2g + 1. The upper bound keeps 2g + 1 <= n: dfs() indexes
    // pw[] and f[] by Len - h and Len - h + 1, and those arrays hold only
    // maxn entries, so probing lengths up to 2n + 1 could run past them.
    int l = 0, r = (n - 1) / 2;
    while (l <= r) {
        int g = l + (r - l) / 2;
        if (check(g * 2 + 1)) {
            ans = max(ans, 2 * g + 1);
            l = g + 1;
        } else {
            r = g - 1;
        }
    }

    // Even lengths 2g, capped the same way so that 2g <= n.
    l = 1, r = n / 2;
    while (l <= r) {
        int g = l + (r - l) / 2;
        if (check(g * 2)) {
            ans = max(ans, g * 2);
            l = g + 1;
        } else {
            r = g - 1;
        }
    }
    cout << ans;
}
// Program entry point: switches the C++ streams to unsynchronised, untied
// mode for faster I/O, then runs the solver once.
int main(){
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);

    solve();
    return 0;
}
// Residue from the submission page (the same empty results table appeared
// four times; the verdicts had not loaded when the page was captured):
// | # | Verdict | Execution time | Memory | Grader output |
// |---|
// | Fetching results... |