#include "bits/stdc++.h"
using namespace std;
#ifdef duc_debug
#include "bits/debug.h"
#else
#define debug(...)
#endif
// Capacity of every per-vertex / per-color array (vertices are 1-indexed).
// NOTE(review): cnt[] is indexed by color, so this presumably also bounds the
// color ids (1..m) — confirm against the problem limits.
constexpr int maxn = 200005;

int n;                       // number of vertices
int c[maxn];                 // c[v] = color of vertex v
int m;                       // number of colors (read in solve(), not otherwise used)
std::vector<int> g[maxn];    // adjacency lists of the tree
int d1;                      // first diameter endpoint (set by get_diam)
int d2;                      // second diameter endpoint (set by get_diam)
int h[maxn];                 // depth of each vertex from the current DFS root
int tin[maxn];               // Euler entry time, written by pre_dfs
int tout[maxn];              // Euler exit time, written by pre_dfs
int timer;                   // Euler-tour clock used by pre_dfs
std::vector<int> st;         // stack of candidate ancestors during dfs()
int res[maxn];               // res[v] = best answer recorded for v so far
int f[maxn];                 // f[v] = height of v's subtree in edges (0 for a leaf)
int cnt[maxn];               // cnt[col] = how many stack entries have color col
int cur;                     // number of distinct colors currently on the stack
std::vector<int> que[maxn];  // que[x] = vertices answered when x is popped from st
// Push vertex x onto the candidate stack and account for its color;
// cur is bumped only when this is the first stack entry with that color.
void add(int x) {
    st.push_back(x);
    if (cnt[c[x]]++ == 0) {
        ++cur;
    }
}
// Pop the top of the candidate stack. While the top entry is still present,
// every query parked on it is answered with cur. At that moment cur counts the
// distinct colors of exactly the stack prefix that ends at this entry. The
// entry's color is then un-counted.
void del() {
    const int top = st.back();
    for (const int q : que[top]) res[q] = max(res[q], cur);
    que[top].clear();
    st.pop_back();
    if (--cnt[c[top]] == 0) --cur;  // last stack entry of this color is gone
}
// Depth fill for the diameter search: for every vertex below u (entered
// from prev), set h[child] = h[parent] + 1. u's own h is left as-is.
void dfs_diam(int u, int prev) {
    for (const int to : g[u]) {
        if (to != prev) {
            h[to] = h[u] + 1;
            dfs_diam(to, u);
        }
    }
}
// Find both endpoints d1, d2 of a diameter using a double sweep. The deepest
// vertex from an arbitrary start (vertex 1) is one endpoint, and the deepest
// vertex from that one is the other. Ties go to the smallest index, since
// max_element returns the first maximum. Leaves h[] holding depths from d1.
void get_diam() {
    // Reset the start depth explicitly instead of relying on zero-initialised
    // globals, so the routine is also correct when called more than once.
    h[1] = 0;
    dfs_diam(1, 0);
    d1 = static_cast<int>(max_element(h + 1, h + n + 1) - h);

    h[d1] = 0;
    dfs_diam(d1, 0);
    // Recompute from scratch: the previous loop seeded the search with the old
    // value of d2, which is 0 only on the very first call.
    d2 = static_cast<int>(max_element(h + 1, h + n + 1) - h);
}
// Root-specific preprocessing for the subtree of u (entered from prev).
// Records Euler entry/exit times in tin/tout, assigns depths h[] to the
// children, and sets f[u] to the height of u's subtree in edges (0 for a leaf).
void pre_dfs(int u, int prev) {
    tin[u] = ++timer;
    int height = 0;
    for (const int child : g[u]) {
        if (child == prev) continue;
        h[child] = h[u] + 1;
        pre_dfs(child, u);
        height = max(height, f[child] + 1);
    }
    f[u] = height;
    tout[u] = timer;
}
// Return the index of the deepest stack entry whose depth h is below k, or -1
// if there is none. The stack is ordered by strictly increasing depth (root
// side first), so "h[st[i]] < k" holds on a prefix of st and a partition point
// search finds where that prefix ends.
int find_pref(int k) {
    const auto first_too_deep = partition_point(
        st.begin(), st.end(), [k](int x) { return h[x] < k; });
    return static_cast<int>(first_too_deep - st.begin()) - 1;
}
// Answer-collecting sweep for the current root (h[] and f[] come from pre_dfs).
// On entry, st holds candidate ancestors of u, root side first, and cnt/cur
// count their colors. Each child v's query is answered lazily: v is parked on
// the deepest stack entry it may count, and del() fills res[v] when that entry
// is popped.
// Only children's queries are registered here, never u's own, so the root
// of the sweep gets no answer from this pass.
void dfs(int u, int prev) {
// ft / se: largest and second-largest f[] over u's children (-1 if absent).
int ft = -1, se = -1;
for (auto v:g[u]) {
if (v == prev) continue;
if (ft < f[v]) {
se = ft;
ft = f[v];
} else if (se < f[v]) {
se = f[v];
}
}
// debug(u, cur, st);
// Visit children in decreasing f (the parent is sorted too but skipped below).
// Entries popped inside an earlier child lie closer to u than that child's
// height, and every later child pops everything within ft + 1 of u anyway,
// so popped entries never have to be restored.
sort(g[u].begin(), g[u].end(), [](const int &x, const int &y) -> bool {
return f[x] > f[y];
});
for (auto v:g[u]) {
if (v == prev) continue;
// mx = height of the deepest OTHER child of u (se when v itself attains ft).
long long mx = (ft == f[v] ? se : ft);
// An entry within mx + 1 of u has a vertex at the same distance in that other
// branch, so it cannot be unique for anything in v's subtree: drop it.
while (!st.empty() and h[u] - h[st.back()] <= mx + 1) {
del();
}
// g[v] includes the parent u, so size <= 2 means v has at most one child.
// With two or more children, u (distance 1) would be popped at v right away
// (mx >= 0 there), so it is only pushed when it can stay.
bool unique = ((int)g[v].size() <= 2);
if (unique) {
add(u);
}
// Entries within f[v] of v are matched by a vertex in v's own subtree; park
// v's query on the deepest entry strictly farther than f[v] from v.
int p = find_pref(h[v] - f[v]);
if (p != -1) {
// debug(v, st, p, st[p]);
que[st[p]].push_back(v);
}
dfs(v, u);
// Undo the push of u unless something deeper in v's subtree already popped it.
if (unique) {
if (!st.empty() and st.back() == u) {
del();
}
}
}
}
// Run one full sweep rooted at `root`. This recomputes depths h[], subtree
// heights f[] and Euler times, then folds this root's answers into res[].
void calc(int root) {
    debug(root);
    timer = 0;
    h[root] = 0;
    pre_dfs(root, 0);
    dfs(root, 0);
}
// Read n and m, the n - 1 tree edges, and then one color per vertex. Sweep
// from both diameter endpoints, keeping the larger answer in res[], and print
// one answer per vertex, one per line.
void solve() {
    cin >> n >> m;
    for (int edges_left = n - 1; edges_left > 0; --edges_left) {
        int a, b;
        cin >> a >> b;
        g[a].push_back(b);
        g[b].push_back(a);
    }
    for (int v = 1; v <= n; ++v) cin >> c[v];

    get_diam();
    calc(d1);
    calc(d2);

    for (int v = 1; v <= n; ++v) cout << res[v] << '\n';
}
// Entry point. Unsyncs C++ streams from C stdio and unties cin from cout for
// fast bulk I/O, then runs the solver.
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    solve();
    return 0;
}
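/*
 * Judge result tables pasted with the submission (results had not loaded yet):
 *
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|---|---|---|---|
 * Fetching results... |
 *
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|---|---|---|---|
 * Fetching results... |
 *
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|---|---|---|---|
 * Fetching results... |
 *
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|---|---|---|---|
 * Fetching results... |
 */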