#include <bits/stdc++.h>
// Range shorthand: all(v) expands to begin(v), end(v).
#define all(v) begin(v), end(v)
// Debug helper that streams "[x = <value>]"; not used in the visible code.
#define dbg(x) "[" #x " = " << x << "]"
// GCC optimization / target-ISA hints (compiler-specific; no effect on the algorithm).
#pragma GCC optimize("O3,unroll-loops,inline-functions,fast-math")
#pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt,tune=native")
#pragma GCC optimize("Ofast")
#pragma GCC optimize("no-stack-protector")
#pragma GCC optimize("fast-math")
#pragma GCC optimize("omit-frame-pointer")
#pragma GCC optimize("inline-small-functions")
#pragma GCC optimize("inline-functions-called-once")
// Debug-only address marker; not used by the algorithm itself.
bool M1;
using namespace std;
// MAXN: capacity of every per-vertex / per-length array (numNode must stay below it).
// mod, base: parameters of the polynomial hash used to compare label sequences.
const int MAXN = 50005, mod = 1e9 + 7, base = 256;
// numNode: number of vertices; mid: path length (in vertices) currently being tested;
// sz: subtree sizes filled by getSize; maxDist: longest leg length inserted into st[]
// for the current centroid (bounds the loop that clears st[]).
int numNode, mid, sz[MAXN], maxDist = 0;
// pre[d]: forward hash of the first d vertices on the current centroid-rooted chain;
// powB[i] = base^i mod `mod`; inv[i] = modular inverse of base^i.
long long pre[MAXN], powB[MAXN], inv[MAXN];
// del[v]: v has already been used as a centroid (removed from further processing).
// isPalind[d]: the first d vertices of the current chain read the same both ways (by hash).
bool del[MAXN], isPalind[MAXN];
// c[i]: label of vertex i (1-indexed); presumably a lowercase letter, since changeType maps 'a' -> 1.
char c[MAXN];
// st[k]: forward hashes of k-vertex legs hanging off the current centroid. int is wide
// enough because every stored hash is already reduced modulo `mod` (< 2^31).
set<int> st[MAXN];
// adj: adjacency lists of the tree (1-indexed, both directions stored).
vector<int> adj[MAXN];
// Square-and-multiply: returns a^k modulo `mod` for k >= 0.
// Each round consumes the lowest bit of k and squares the running base.
long long bin_pow(long long a, long long k){
    long long acc = 1;
    for (; k != 0; k >>= 1, a = a * a % mod) {
        if (k & 1) acc = acc * a % mod;
    }
    return acc;
}
// Computes the size of u's subtree inside the current (not yet deleted) component,
// rooted away from `par`. Stores it in sz[u] and returns it.
int getSize(int u, int par = 0){
    int total = 1;  // u itself
    for (int child : adj[u]) {
        if (child == par || del[child]) continue;
        total += getSize(child, u);
    }
    return sz[u] = total;
}
// Walks from u toward the heavy side until no live child subtree holds more than
// half of the component (sizeU vertices). Relies on sz[] from a getSize call rooted
// at the starting u. The vertex where the walk stops is the centroid.
int getCentroid(int sizeU, int u, int par = 0){
    const int half = sizeU >> 1;
    for (;;) {
        int heavy = -1;
        for (int v : adj[u]) {
            if (v != par && !del[v] && sz[v] > half) {
                heavy = v;
                break;
            }
        }
        if (heavy < 0) return u;
        par = u;
        u = heavy;
    }
}
// Maps a label character to its hash digit: 'a' -> 1, 'b' -> 2, ...
int changeType(char c){
    return 1 + (c - 'a');
}
// Hash of chain positions L..R (1-based), re-based so position L has weight base^0.
// This makes it directly comparable with a leg hash stored in st[R - L + 1].
// Assumes pre[] holds values in [0, mod) and pre[0] == 0.
long long getHash(int L, int R){
    long long diff = pre[R] - pre[L - 1];
    if (diff < 0) diff += mod;
    return diff * inv[L - 1] % mod;
}
// DFS from u (a vertex below the current centroid). It looks for a path of exactly
// `mid` vertices through the centroid with these properties:
//   - the leg ending in u's subtree is the longer (or equal) one;
//   - the shorter leg, of k vertices, is one already stored in st[k] from
//     previously scanned child subtrees.
// Notation: a_i = changeType of the i-th vertex on the chain centroid -> u
// (centroid is position 1).
//   sumDown : forward hash  sum_{i=1}^{dist-1} a_i * base^(i-1)  of centroid..parent(u)
//   sumUp   : reversed (Horner) hash of the same vertices
//   dist    : position of u on the chain, i.e. vertex count centroid..u inclusive
// Writes pre[dist] and isPalind[dist] for the current chain. Smaller indices must
// already describe u's ancestors (the caller seeds pre[1] for the centroid).
// Returns 1 as soon as a match is found. Comparisons are hash-based, so a
// collision could in principle yield a false positive.
bool check(int u, long long sumDown, long long sumUp, int dist = 2, int par = 0){
    // The chain already has more vertices than the target length.
    if (dist > mid) return 0;
    // Append u at position dist: the forward hash adds a_dist * base^(dist-1) ...
    sumDown = (sumDown + changeType(c[u]) * powB[dist - 1] % mod) % mod;
    // ... while the reversed hash shifts the existing digits up one power and adds a_dist.
    sumUp = (sumUp * base) % mod;
    sumUp = (sumUp + changeType(c[u])) % mod;
    pre[dist] = sumDown;
    // Forward and reversed hashes agree iff a_1..a_dist is a palindrome (modulo collisions).
    isPalind[dist] = (sumUp == sumDown);
    // This leg has dist vertices and needs a partner leg of k = mid - dist vertices.
    // Only the case k <= dist is handled here.
    if ((dist << 1) >= mid){
        // The path reads a_dist..a_1 followed by the partner b_1..b_k (b_1 is next to the
        // centroid). It is a palindrome iff a_1..a_{dist-k} is a palindrome ...
        if (isPalind[dist - (mid - dist)]){
            // ... and a_{dist-k+1}..a_dist equals b_1..b_k. The segment's re-based forward
            // hash is looked up among the stored k-vertex leg hashes. For k == 0 the
            // segment hash is 0, which matches only if the caller seeded st[0] with 0.
            if (st[mid - dist].count(getHash(dist - (mid - dist) + 1, dist))){
                return 1;
            }
        }
    }
    for(int v: adj[u]) if (!del[v] && v != par){
        if (check(v, sumDown, sumUp, dist + 1, u)) return 1;
    }
    // Backtrack: this position will be rewritten for u's siblings.
    isPalind[dist] = 0;
    return 0;
}
// Records in st[dist] the forward hash  sum_{i=1}^{dist} b_i * base^(i-1)  of every leg
// b_1..b_dist that starts at u (b_1 = u, a child of the current centroid) and goes down
// its subtree. maxDist is raised to the longest leg stored.
//
// The partner-leg lookup in check() uses st[mid - dist] only when 2*dist >= mid, so it
// never asks for a leg longer than mid/2 vertices. Deeper legs are therefore neither
// traversed nor stored. This keeps the sets and the DFS smaller without changing any
// lookup result.
void update(int u, long long sumDown, int dist = 1, int par = 0){
    if (dist > mid / 2) return;
    sumDown = (sumDown + changeType(c[u]) * powB[dist - 1] % mod) % mod;
    maxDist = max(maxDist, dist);
    st[dist].insert(sumDown);  // value < mod, so it fits in int
    for(int v: adj[u]) if (!del[v] && v != par){
        update(v, sumDown, dist + 1, u);
    }
}
// Centroid-decomposes the component containing u. Reports whether some path of exactly
// `mid` vertices in it has a palindromic label sequence (by hash).
//
// Every path through a centroid splits into two legs hanging off different children.
// check() only handles the case where the leg in the subtree being scanned is the longer
// one, so the children are swept once in each direction to cover both orders of every
// pair. The adjacency lists are not modified.
//
// Returns 1 on the first match; st[] is emptied, but del/pre/isPalind are left for the
// caller to reset. Returns 0 if no such path exists in this component.
bool decomp(int u = 1){
    int centroid = getCentroid(getSize(u), u);
    del[centroid] = 1;
    const auto &nbr = adj[centroid];
    const long long centroidVal = changeType(c[centroid]);

    // One sweep over the centroid's live children in the given direction. Each subtree
    // is first tested against the legs of the subtrees already scanned, then its own legs
    // are added. st[] is always emptied afterwards so the next sweep starts clean.
    auto sweep = [&](bool backwards) -> bool {
        maxDist = 0;
        st[0].insert(0);  // empty partner leg: lets a single-leg path (k == 0) match
        bool found = false;
        for (size_t i = 0; i < nbr.size() && !found; i++){
            int v = backwards ? nbr[nbr.size() - 1 - i] : nbr[i];
            if (del[v]) continue;
            pre[1] = centroidVal;  // the centroid is position 1 of every chain
            if (check(v, centroidVal, centroidVal)) found = true;
            else update(v, 0);
        }
        for (int dist = 0; dist <= maxDist; dist++) st[dist].clear();
        return found;
    };

    if (sweep(false) || sweep(true)) return 1;
    for (int v : nbr) if (!del[v] && decomp(v)) return 1;
    return 0;
}
// Binary-search predicate: does the tree contain a path of exactly `len` vertices whose
// labels form a palindrome? Resets all global scratch state, then runs the centroid
// decomposition from vertex 1.
bool check(const int &len){
    mid = len;
    // Impossible: more vertices than the tree has.
    if (len > numNode) return 0;
    // Lengths 0 and 1 are trivially palindromic.
    if (len < 2) return 1;
    // Clear everything the previous probe may have left behind.
    for (int d = numNode; d >= 0; --d) st[d].clear();
    memset(pre, 0, sizeof(pre));
    memset(del, 0, sizeof(del));
    memset(isPalind, 0, sizeof(isPalind));
    // Chains of 0 or 1 vertices are palindromes by definition.
    isPalind[0] = 1;
    isPalind[1] = 1;
    return decomp();
}
// Longest palindromic path length with parity `type` (0 = even, 1 = odd). Works by
// binary searching the half-length, which is valid because trimming both ends of a
// palindromic path leaves a palindromic path of the same parity. Returns -1 if no
// probe succeeds.
int calc(int type){
    int best = -1;
    int lo = 0, hi = numNode >> 1;
    while (lo <= hi){
        const int half = lo + (hi - lo) / 2;
        const int len = 2 * half + type;
        if (!check(len)){
            hi = half - 1;
            continue;
        }
        best = len;
        lo = half + 1;
    }
    return best;
}
// Precomputes powB[i] = base^i and inv[i] = base^{-i} (mod `mod`) for i = 0..numNode.
// mod = 1e9+7 is prime, so base^{-1} = base^(mod-2) by Fermat's little theorem. It is
// computed once, and each inv[i] is the previous one times base^{-1}. This replaces a
// full exponentiation per index and yields exactly the same values.
void init(){
    powB[0] = 1;
    inv[0] = 1;
    const long long invBase = bin_pow(base, mod - 2);
    for(int i = 1; i <= numNode; i++){
        powB[i] = (powB[i - 1] * base) % mod;
        inv[i] = (inv[i - 1] * invBase) % mod;
    }
}
// Reads the vertex count, one label per vertex (1-indexed), then numNode - 1
// undirected edges into adj.
void input(){
    cin >> numNode;
    for (int i = 1; i <= numNode; ++i) cin >> c[i];
    int a, b;
    for (int edge = 1; edge < numNode; ++edge){
        cin >> a >> b;
        adj[a].push_back(b);
        adj[b].push_back(a);
    }
}
// Prints the longer of the best even-length and best odd-length palindromic paths.
void solve(){
    const int bestEven = calc(0);
    const int bestOdd = calc(1);
    cout << max(bestEven, bestOdd) << '\n';
}
// Entry point: read input, precompute hash powers, solve. Timing goes to stderr.
int main(){
    ios_base::sync_with_stdio(0); cin.tie(0); cout.tie(0);
    // Local-testing convenience: if test.inp exists, read from it and write to test.out.
    // The probe handle is closed. A failed redirection is reported instead of silently
    // continuing on the wrong streams; this also checks freopen's result, which is
    // marked warn_unused_result.
    if (FILE *probe = fopen("test.inp", "r")){
        fclose(probe);
        if (!freopen("test.inp", "r", stdin) || !freopen("test.out", "w", stdout)){
            cerr << "cannot redirect test.inp/test.out\n";
            return 1;
        }
    }
    input();
    init();
    solve();
    // No memory estimate is printed: subtracting the addresses of unrelated objects
    // (such as a global and a local variable) is undefined behavior and meaningless.
    cerr << (1.0 * clock()) / CLOCKS_PER_SEC << ".s\n";
}
Standard error (stderr) messages at compile time
lampice.cpp: In function 'int main()':
lampice.cpp:162:16: warning: ignoring return value of 'FILE* freopen(const char*, const char*, FILE*)' declared with attribute 'warn_unused_result' [-Wunused-result]
162 | freopen("test.inp", "r", stdin);
| ~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~
lampice.cpp:163:16: warning: ignoring return value of 'FILE* freopen(const char*, const char*, FILE*)' declared with attribute 'warn_unused_result' [-Wunused-result]
163 | freopen("test.out", "w", stdout);
| ~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~
| # | Verdict | Execution time | Memory | Grader output |
|---|
| Fetching results... |
| # | Verdict | Execution time | Memory | Grader output |
|---|
| Fetching results... |
| # | Verdict | Execution time | Memory | Grader output |
|---|
| Fetching results... |
| # | Verdict | Execution time | Memory | Grader output |
|---|
| Fetching results... |