#include <bits/stdc++.h>
using namespace std;
#define int long long
// Array capacity (towns are 1-indexed) and a sentinel cost larger than any
// real answer, used to reject a centroid whose capital cannot be built.
const int maxn = 2e5 + 5, INF = 1e9;
// n towns and k colours (k is read in main but not otherwise used here).
int n, k;
// c[i]: colour of town i.
int c[maxn];
// Tree adjacency lists; solve() erases each processed centroid from its
// neighbours' lists so later searches stay inside the remaining pieces.
vector<int> adj[maxn];
// mem[col]: every town whose colour is col.
vector<int> mem[maxn];
// Smallest colour count found for a valid capital (main prints ans - 1).
int ans;
// Stamp identifying the component currently being processed by solve().
int cnt = 0;
// sz: subtree size from dfs_sz; vis[u] == cnt means u is in the current
// component; fa: parent of each town in the current rooting.
int sz[maxn], vis[maxn], fa[maxn];
// Towns of the current component, in the order dfs_sz discovered them.
vector<int> nodes;
// Depth-first traversal from u over adj. For every reached town it:
//   - appends the town to `nodes` (pre-order, same order as a recursive DFS
//     that visits adj[] front to back),
//   - stamps vis[] with the current cnt,
//   - sets fa[] to its parent, and
//   - fills sz[] with its subtree size.
// The caller must set fa[u] beforehand (solve() uses 0).
// An explicit stack replaces recursion, so a path-shaped tree with ~2e5
// towns cannot overflow the call stack.
void dfs_sz(int u) {
    // Each frame holds (town, index of the next neighbour to examine).
    std::vector<std::pair<int, int>> stk;
    nodes.push_back(u);
    sz[u] = 1, vis[u] = cnt;
    stk.push_back({u, 0});
    while (!stk.empty()) {
        const int x = stk.back().first;
        const int i = stk.back().second;
        if (i < static_cast<int>(adj[x].size())) {
            stk.back().second++;
            const int v = adj[x][i];
            if (v == fa[x]) continue;  // don't walk back to the parent
            fa[v] = x;
            nodes.push_back(v);
            sz[v] = 1, vis[v] = cnt;
            stk.push_back({v, 0});
        } else {
            // All children of x are finished: fold its size into the parent.
            stk.pop_back();
            if (!stk.empty()) sz[stk.back().first] += sz[x];
        }
    }
}
// Re-roots the tree part reachable from u at u. Afterwards fa[w] is the
// neighbour of w on its path towards u, for every reachable town w.
// The caller sets fa[u] itself (solve() uses 0).
// Uses an explicit stack instead of recursion to avoid deep call chains on
// path-shaped trees; only the final fa[] values matter, and they match a
// recursive traversal.
void dfs_fa(int u) {
    std::vector<int> stk;
    stk.push_back(u);
    while (!stk.empty()) {
        const int x = stk.back();
        stk.pop_back();
        for (int v : adj[x]) {
            if (v != fa[x]) {  // fa[x] was fixed when x was pushed
                fa[v] = x;
                stk.push_back(v);
            }
        }
    }
}
// Per-component scratch flags used by solve():
//   - taken[u]: town u has already been added to the capital being grown
//     (taken[0] is set as a sentinel for the virtual root above the centroid).
//   - coltaken[col]: colour col has already been queued.
bool taken[maxn], coltaken[maxn];
void solve(int U) {
cnt++; nodes.clear();
fa[U] = 0;
dfs_sz(U);
int cent, nn = nodes.size();
for (int u:nodes) {
bool flag = (nn - sz[u] <= nn/2);
for (int v:adj[u]) if (v != fa[u] && sz[v] > nn/2) flag = false;
if (flag) cent = u;
}
for (int u:nodes) taken[u] = coltaken[c[u]] = false;
// for (int u:nodes) cout << u << " "; cout << endl;
// cout << cent << endl;
taken[0] = true;
fa[cent] = 0;
dfs_fa(cent);
queue<int> q;
q.push(c[cent]);
coltaken[c[cent]] = true;
int curans = 0;
while (q.size()) {
int cur = q.front();
curans++;
q.pop();
for (int u:mem[cur]) {
if (vis[u] != cnt) {
curans = INF;
break;
}
while (!taken[u]) {
if (!coltaken[c[u]]) {
coltaken[c[u]] = true;
q.push(c[u]);
}
taken[u] = true;
u = fa[u];
}
}
}
ans = min(curans, ans);
for (int v:adj[cent]) {
adj[v].erase(find(adj[v].begin(), adj[v].end(), cent));
solve(v);
}
}
// Reads the tree and the town colours, runs the centroid decomposition, and
// prints the minimum number of colour merges (best colour count minus one).
signed main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    std::cin >> n >> k;

    // The tree has exactly n - 1 undirected edges.
    for (int edge = 1; edge < n; ++edge) {
        int a, b;
        std::cin >> a >> b;
        adj[a].push_back(b);
        adj[b].push_back(a);
    }

    // Record each town's colour and bucket the town under that colour.
    for (int town = 1; town <= n; ++town) {
        std::cin >> c[town];
        mem[c[town]].push_back(town);
    }

    // Every town has one colour, so n bounds the colour count from above.
    ans = n;
    solve(1);
    std::cout << ans - 1 << "\n";
}
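/*
 * Grader results table pasted from the submission page. The results had not
 * finished loading when it was copied, so no verdicts were recorded:
 *
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 */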