#include <bits/stdc++.h>
using namespace std;
// Prime modulus for the polynomial path hashes (1'000'000'009).
constexpr int mod = 1000000009;
// Hash base; larger than every letter value (letters are mapped to 1..26).
constexpr int base = 29;
// Modular addition. Assumes x + y < 2 * mod, so one conditional subtraction
// brings the sum back into [0, mod). With mod = 1e9 + 9, 2 * mod still fits in int.
int add(int x, int y)
{
    const int sum = x + y;
    return sum < mod ? sum : sum - mod;
}
int mult(int x, int y)
{
return (int64_t)x * y % mod;
}
// 64-bit Mersenne Twister seeded from the steady clock.
mt19937_64 rng(chrono::steady_clock::now().time_since_epoch().count());

// Returns a uniformly distributed integer in [st, dr]; requires st <= dr.
// The distribution is instantiated on int. Instantiating it on an unsigned type
// (as before, via mt19937::result_type) turns a negative `st` into a huge value
// and breaks uniform_int_distribution's a <= b precondition.
// NOTE(review): this helper appears unused in the visible code.
int random(int st, int dr)
{
    uniform_int_distribution<int> gen(st, dr);
    return gen(rng);
}
// Centroid-decomposition solver for the question "does the tree contain a
// palindromic path with exactly k nodes?". Path strings are compared through
// polynomial hashes modulo `mod` with base `base`; equal hashes are treated as
// equal strings.
//
// Nodes are numbered 1..n. Index 0 is a sentinel that no per-node array ever
// writes, so every array holds 0 there. solve() relies on this: par[centroid]
// is 0, hashup[0] == hashdown[0] == depth[0] == 0, and nodes[0] == 0 stands
// for an empty palindromic middle part.
struct lampice
{
    int n;                 // number of nodes
    vector<vector<int>> g; // adjacency lists
    vector<char> colors;   // letter at each node (mapped to c - 'a' + 1; assumes lowercase input)
    vector<bool> seen;     // nodes already used as centroids in the current decomposition
    vector<int> sz;        // subtree sizes from the most recent dfs_size call
    vector<int> depth;     // edge distance from the current centroid
    vector<int> hashup;    // hash of centroid..v with v at the lowest power of base
    vector<int> hashdown;  // hash of centroid..v with the centroid at the lowest power of base
    vector<int> par;       // parent towards the current centroid (0 for the centroid itself)
    vector<int> nodes;     // stack of the current centroid..v path; nodes[1] is the centroid

    // Allocates storage for a tree whose nodes are 1.._n.
    void init(int _n)
    {
        n = _n;
        g = vector<vector<int>>(n + 1);
        colors = vector<char>(n + 1);
        seen = vector<bool>(n + 1);
        nodes = sz = depth = hashup = hashdown = par = vector<int>(n + 1);
    }

    void set_color(int pos, char x)
    {
        colors[pos] = x;
    }

    // Adds the undirected edge a - b.
    void add_edge(int a, int b)
    {
        g[a].push_back(b);
        g[b].push_back(a);
    }

    // Computes sz[] for the component of unseen nodes containing `node`,
    // rooted at `node`.
    void dfs_size(int node, int parent)
    {
        sz[node] = 1;
        for (auto i : g[node])
        {
            if (i != parent && !seen[i])
            {
                dfs_size(i, node);
                sz[node] += sz[i];
            }
        }
    }

    // Walks from `node` into any child whose subtree exceeds size / 2 and
    // returns the first node with no such child, i.e. the component's centroid.
    // Requires sz[] to have been filled by dfs_size from the same root.
    int find_centroid(int node, int parent, int size)
    {
        for (auto i : g[node])
        {
            if (i != parent && !seen[i] && sz[i] > size / 2)
            {
                return find_centroid(i, node, size);
            }
        }
        return node;
    }

    // Returns 1 if some path with exactly k nodes that passes through the
    // centroid `node` (and uses only unseen nodes) is a palindrome, else 0.
    int solve(int node, int k)
    {
        // Pass 1: parent pointers and depths relative to the centroid.
        int max_depth = 0;
        function<void(int, int, int)> dfs_init = [&](int v, int parent, int d)
        {
            par[v] = parent;
            depth[v] = d;
            max_depth = max(max_depth, d);
            for (int to : g[v])
            {
                if (to != parent && !seen[to])
                {
                    dfs_init(to, v, d + 1);
                }
            }
        };
        dfs_init(node, 0, 0);

        vector<int> power(max_depth + 1);
        power[0] = 1;
        for (int i = 1; i <= max_depth; ++i)
        {
            power[i] = mult(power[i - 1], base);
        }

        // Pass 2: hashes of every centroid..v path in both reading directions.
        // The centroid..v path is a palindrome iff hashup[v] == hashdown[v]
        // (up to hash collisions).
        function<void(int, int)> compute_hash = [&](int v, int parent)
        {
            const int letter = colors[v] - 'a' + 1;
            hashup[v] = add(mult(base, hashup[parent]), letter);
            hashdown[v] = add(hashdown[parent], mult(power[depth[v]], letter));
            for (int to : g[v])
            {
                if (to != parent && !seen[to])
                {
                    compute_hash(to, v);
                }
            }
        };
        compute_hash(node, 0);

        // Hash of the downward segment b..a (b is a or one of its ancestors),
        // read from b towards a. Works for b == centroid because
        // par[centroid] == 0 and hashup[0] == depth[0] == 0.
        auto get_hashup = [&](int a, int b)
        {
            const int c = par[b];
            return add(hashup[a], mod - mult(hashup[c], power[depth[a] - depth[c]]));
        };

        // Multiplicity of each segment hash (child-of-centroid .. x) over the
        // subtrees currently registered. A counting map keeps membership
        // queries O(1) on average even when many nodes share one hash;
        // unordered_multiset::count is linear in the number of duplicates,
        // which made uniform-letter, high-degree trees quadratic.
        unordered_map<int, int> available;
        function<void(int, int, int)> add_subtree = [&](int v, int parent, int root)
        {
            ++available[get_hashup(v, root)];
            for (int to : g[v])
            {
                if (to != parent && !seen[to])
                {
                    add_subtree(to, v, root);
                }
            }
        };
        function<void(int, int, int)> remove_subtree = [&](int v, int parent, int root)
        {
            // The hash is present: this subtree was registered by add_subtree.
            auto it = available.find(get_hashup(v, root));
            if (--it->second == 0)
            {
                available.erase(it);
            }
            for (int to : g[v])
            {
                if (to != parent && !seen[to])
                {
                    remove_subtree(to, v, root);
                }
            }
        };

        bool found = false;
        int m = 1; // top of the `nodes` stack; always equals d inside dfs
        nodes[1] = node;
        // d = number of nodes on the centroid..v path, counting the centroid.
        // Candidate path: (other side, `other` nodes) + centroid..v. The part
        // centroid..nodes[middle] must be a palindrome on its own, and the
        // remaining `other` nodes below it must mirror a registered segment in
        // a different subtree.
        function<void(int, int, int)> dfs = [&](int v, int parent, int d)
        {
            nodes[++m] = v;
            const int other = k - d;
            const int middle = 2 * d - k;
            if (other >= 0 && middle >= 0)
            {
                if (other == 0)
                {
                    if (hashup[v] == hashdown[v])
                    {
                        found = true;
                    }
                }
                else
                {
                    // nodes[m - other] == nodes[middle]; this is the sentinel 0
                    // (an empty, trivially palindromic middle) when middle == 0.
                    const int pivot = nodes[m - other];
                    if (hashup[pivot] == hashdown[pivot] &&
                        available.count(get_hashup(v, nodes[m - other + 1])))
                    {
                        found = true;
                    }
                }
            }
            // Below depth k the path already has more than k nodes, so no
            // check can succeed there; also stop as soon as an answer is known.
            if (!found && d < k)
            {
                for (int to : g[v])
                {
                    if (to != parent && !seen[to])
                    {
                        dfs(to, v, d + 1);
                        if (found)
                        {
                            break;
                        }
                    }
                }
            }
            --m;
        };

        for (int child : g[node])
        {
            if (!seen[child])
            {
                add_subtree(child, node, child);
            }
        }
        for (int child : g[node])
        {
            if (seen[child])
            {
                continue;
            }
            // Match child's subtree only against the other subtrees.
            remove_subtree(child, node, child);
            dfs(child, node, 2);
            if (found)
            {
                break;
            }
            add_subtree(child, node, child);
        }
        return found;
    }

    // Returns 1 if the component of unseen nodes containing `node` holds a
    // palindromic path with exactly k nodes, else 0. Marks centroids in
    // seen[] and may stop early, so reinit() must be called before the next
    // decomposition.
    int decomp(int node, int k)
    {
        dfs_size(node, 0);
        // A simple path with k nodes cannot fit in a smaller component.
        if (sz[node] < k)
        {
            return 0;
        }
        node = find_centroid(node, 0, sz[node]);
        if (solve(node, k))
        {
            return 1;
        }
        seen[node] = true;
        for (auto i : g[node])
        {
            if (!seen[i] && decomp(i, k))
            {
                return 1;
            }
        }
        return 0;
    }

    // Clears all centroid marks so the tree can be decomposed again.
    void reinit()
    {
        seen.assign(n + 1, false);
    }
};
int32_t main()
{
cin.tie(nullptr)->sync_with_stdio(false);
int n;
cin >> n;
lampice g;
g.init(n);
for (int i = 1; i <= n; ++i)
{
char x;
cin >> x;
g.set_color(i, x);
}
for (int i = 1; i < n; ++i)
{
int u, v;
cin >> u >> v;
g.add_edge(u, v);
}
int ans = 1;
int st = 1, dr = n / 2;
while (st <= dr)
{
int mid = (st + dr) / 2;
if (g.decomp(1, 2 * mid))
{
ans = max(ans, 2 * mid);
st = mid + 1;
}
else
{
dr = mid - 1;
}
g.reinit();
}
st = max(1, ans / 2), dr = n / 2;
while (st <= dr)
{
int mid = (st + dr) / 2;
if (g.decomp(1, 2 * mid + 1))
{
ans = max(ans, 2 * mid + 1);
st = mid + 1;
}
else
{
dr = mid - 1;
}
g.reinit();
}
cout << ans;
}
/*
Grader results:
| # | Verdict | Execution time | Memory | Grader output     |
|---|---------|----------------|--------|-------------------|
| 1 | Correct | 8 ms           | 340 KB | Output is correct |
| 2 | Correct | 25 ms          | 436 KB | Output is correct |
| 3 | Correct | 85 ms          | 544 KB | Output is correct |
| 4 | Correct | 137 ms         | 596 KB | Output is correct |
| 5 | Correct | 0 ms           | 212 KB | Output is correct |
| 6 | Correct | 0 ms           | 212 KB | Output is correct |
| 7 | Correct | 0 ms           | 212 KB | Output is correct |
*/
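/*
Grader results:
| # | Verdict             | Execution time | Memory  | Grader output       |
|---|---------------------|----------------|---------|---------------------|
| 1 | Execution timed out | 5063 ms        | 8768 KB | Time limit exceeded |
| 2 | Halted              | 0 ms           | 0 KB    | -                   |
*/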
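/*
Grader results:
| # | Verdict             | Execution time | Memory  | Grader output       |
|---|---------------------|----------------|---------|---------------------|
| 1 | Execution timed out | 5051 ms        | 7768 KB | Time limit exceeded |
| 2 | Halted              | 0 ms           | 0 KB    | -                   |
*/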
/*
Grader results:
| # | Verdict             | Execution time | Memory  | Grader output       |
|---|---------------------|----------------|---------|---------------------|
| 1 | Correct             | 8 ms           | 340 KB  | Output is correct   |
| 2 | Correct             | 25 ms          | 436 KB  | Output is correct   |
| 3 | Correct             | 85 ms          | 544 KB  | Output is correct   |
| 4 | Correct             | 137 ms         | 596 KB  | Output is correct   |
| 5 | Correct             | 0 ms           | 212 KB  | Output is correct   |
| 6 | Correct             | 0 ms           | 212 KB  | Output is correct   |
| 7 | Correct             | 0 ms           | 212 KB  | Output is correct   |
| 8 | Execution timed out | 5063 ms        | 8768 KB | Time limit exceeded |
| 9 | Halted              | 0 ms           | 0 KB    | -                   |
*/