#include<bits/stdc++.h>
using namespace std;
// Capacity of every per-vertex array below (vertices are 0-indexed).
const int MXN = 500002;

int n;                      // number of tree vertices (read in main)
int k;                      // number of colour classes; sta[] uses indices 0..k-1
int dsu[MXN];               // union-find parent links over contracted vertex groups
int h[MXN];                 // depth of each vertex from vertex 0, filled by dfs()
int deg[MXN];               // degree of each DSU representative in the contracted tree
int sta[MXN];               // first vertex seen with each colour, or -1 if none yet
int par[MXN];               // tree parent of each vertex (rooted at 0), filled by dfs()
std::vector<int> g[MXN];    // adjacency lists of the input tree
// Roots the tree stored in g[] at vertex v and fills par[] and h[] for every
// vertex reachable from v except v itself: par[u] is u's neighbour on the path
// back to v, and h[u] = h[v] + (number of edges between v and u).
//
// p is the vertex treated as v's parent; it is never entered. -1 means none.
//
// Uses an explicit stack instead of recursion. On a path-shaped tree with
// ~5e5 vertices, the recursive form needs one call frame per vertex, which can
// exceed common default stack limits (e.g. 8 MiB on many Linux systems).
// Assumes g[] describes a tree (no cycles), as the recursive original did.
void dfs(int v, int p=-1) {
    std::vector<std::pair<int, int>> todo;  // (vertex, vertex it was reached from)
    todo.emplace_back(v, p);
    while(!todo.empty()) {
        const int cur = todo.back().first;
        const int from = todo.back().second;
        todo.pop_back();
        for(int u : g[cur]) {
            if(u == from) continue;
            par[u] = cur;
            h[u] = h[cur] + 1;  // h[cur] is already final when cur is popped
            todo.emplace_back(u, cur);
        }
    }
}
// Union-find lookup: returns the representative of v's set and points every
// vertex on the visited chain directly at that representative.
//
// Written iteratively (find the root, then compress) rather than recursively.
// merge() links representatives to their tree parents (dsu[x] = par[x]), so a
// dsu chain can be as long as the tree depth. Recursing along such a chain can
// overflow the stack.
int find(int v) {
    int root = v;
    while(dsu[root] != root) root = dsu[root];
    while(v != root) {
        const int next = dsu[v];
        dsu[v] = root;
        v = next;
    }
    return root;
}
// Puts every vertex on the tree path between u and v into a single DSU set.
//
// A representative is only ever linked to its tree parent (dsu[x] = par[x]),
// so each set's representative is its vertex of smallest depth h[]. The loop
// lifts whichever representative is deeper (the u side on ties) one edge
// towards the root and re-resolves it, until both sides share a set.
// Requires dfs() to have filled par[] and h[].
void merge(int u, int v) {
    int a = find(u);
    int b = find(v);
    while(a != b) {
        const bool aIsDeeper = !(h[a] < h[b]);
        const int lifted = aIsDeeper ? a : b;
        const int stay = aIsDeeper ? b : a;
        dsu[lifted] = par[lifted];
        a = find(lifted);
        b = stay;
    }
}
// Reads a tree with n vertices (1-indexed edges) and a colour in 1..k for each
// vertex. It then contracts the tree so that every vertex lying on a path
// between two same-coloured vertices joins their group. It counts the
// contracted vertices of degree 1 (leaves) and prints ceil(leaves / 2).
// NOTE(review): presumably this is the minimum number of edges to add so the
// contracted tree has no bridges -- confirm against the problem statement.
int32_t main() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);

    cin >> n >> k;
    for(int e = 1; e < n; ++e) {
        int a, b;
        cin >> a >> b;
        --a;
        --b;
        g[a].push_back(b);
        g[b].push_back(a);
    }
    dfs(0);

    // Group vertices: each new vertex of an already-seen colour is merged with
    // the first vertex of that colour, absorbing the tree path between them.
    fill(sta, sta + k, -1);
    iota(dsu, dsu + n, 0);
    for(int node = 0; node < n; ++node) {
        int colour;
        cin >> colour;
        --colour;
        if(sta[colour] == -1) sta[colour] = node;
        else merge(sta[colour], node);
    }

    // Visit each tree edge once (a < b). Edges joining different groups are
    // the edges of the contracted tree.
    for(int a = 0; a < n; ++a) {
        for(int b : g[a]) {
            if(b <= a) continue;
            const int ra = find(a);
            const int rb = find(b);
            if(ra == rb) continue;
            ++deg[ra];
            ++deg[rb];
        }
    }

    int leaves = 0;
    for(int v = 0; v < n; ++v)
        if(dsu[v] == v && deg[v] == 1) ++leaves;
    cout << (leaves + 1) / 2 << '\n';
    return 0;
}
/*
 * Judge result table captured together with this submission. The results had
 * not finished loading when it was copied, so every row is a placeholder.
 *
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 */