#include <bits/stdc++.h>
using namespace std;
#define mp make_pair
#define pb push_back
#define fi first
#define se second

// Array capacity for cities and states.
const int N = 5e5 + 7;

// n = number of cities, k = number of states; u, v are scratch endpoints used while reading edges.
int n, k, u, v;

// Per-city arrays.
int s[N];        // state that each city belongs to
int level[N];    // depth of each city (main sets the root's depth to 1)
int sub[N];      // number of states whose cities' LCA is this city
int deg[N];      // bookkeeping counter filled in by predfs
int par[20][N];  // binary-lifting ancestor table, par[i][x] = 2^i-th ancestor of x

// Per-state / DSU arrays.
int pset[N];     // DSU parent of each state
int sz[N];       // DSU counter kept at component roots
int ret[N];      // compressed group id of each state

// Counters.
int cnt = 0;     // number of compressed groups
int ans = 0;     // number of leaves of the compressed group tree

bool vis[N], col[N];  // not referenced by the code shown here

vector <int> edge[N];  // tree adjacency lists
vector <int> tp[N];    // cities of each state
vector <int> g[N];     // adjacency lists of the compressed group tree
vector <int> leaf;     // childless cities, as collected by predfs

map < pair <int, int>, int > ed;  // seen group pairs (duplicate-edge guard)
// Reads the next non-negative decimal integer from stdin.
// Any non-digit characters before it are skipped (a '-' sign is skipped too,
// so inputs are assumed non-negative). The first character after the number
// is consumed. Returns 0 if EOF is reached before any digit.
// getchar() returns int; keeping it in an int is what lets EOF be detected.
// The previous `char` version looped forever at end of input.
int read(){
    int c = getchar();
    while(c != EOF && (c < '0' || c > '9')) c = getchar();
    if(c == EOF) return 0;
    int value = 0;
    while(c >= '0' && c <= '9'){
        value = value * 10 + (c - '0');
        c = getchar();
    }
    return value;
}
// Roots the tree stored in `edge` at city u, whose parent is p (pass p == u for the root).
// For every non-root city y it sets par[0][y] and level[y] = level[parent] + 1.
// For every city x it adds its number of children to deg[x], so main's
// bottom-up pass releases a parent only after all of its children are done.
// (The old code incremented deg of the child instead.)
// Childless cities are appended to `leaf` in the same post-order a recursive DFS would produce.
// Iterative so that a path-shaped tree with ~5e5 cities cannot overflow the call stack.
void predfs(int u, int p){
    // Frame: {city, its parent, index of next neighbour to examine, children found so far}.
    vector<array<int, 4> > stk;
    stk.push_back({u, p, 0, 0});
    while(!stk.empty()){
        array<int, 4> &top = stk.back();
        int x = top[0];
        int px = top[1];
        if(top[2] < (int)edge[x].size()){
            int y = edge[x][top[2]++];
            if(y == px) continue;
            top[3]++;
            deg[x]++;
            par[0][y] = x;
            level[y] = level[x] + 1;
            // push_back may invalidate `top`; it is not used again in this iteration.
            stk.push_back({y, x, 0, 0});
        }else{
            if(top[3] == 0) leaf.pb(x);
            stk.pop_back();
        }
    }
}
// Lowest common ancestor of cities u and v, using the binary-lifting table
// par[][] (20 levels) and the depths in level[]. The root is its own ancestor.
int lca(int u, int v){
    // Arrange for v to be the deeper (or equally deep) endpoint.
    if(level[v] < level[u]) swap(u, v);
    // Lift v up to u's depth.
    for(int j = 19; j >= 0; --j){
        const int up = par[j][v];
        if(level[up] >= level[u]) v = up;
    }
    if(u == v) return u;
    // Lift both endpoints while their ancestors differ; they stop just below the LCA.
    for(int j = 19; j >= 0; --j){
        if(par[j][u] == par[j][v]) continue;
        u = par[j][u];
        v = par[j][v];
    }
    return par[0][u];
}
// Returns the DSU root of state x, compressing the path from x to the root.
// Iterative: unionset links roots without ranking, so chains as long as the
// number of states can form. A recursive find could then overflow the stack.
int fset(int x){
    int root = x;
    while(pset[root] != root) root = pset[root];
    // Second pass: point every node on the path directly at the root.
    while(pset[x] != root){
        int nxt = pset[x];
        pset[x] = root;
        x = nxt;
    }
    return root;
}
// Merges the DSU components of states u and v. The root of v's component
// becomes the root of the merged one and takes over u's root's sz counter.
void unionset(int u, int v){
    const int ru = fset(u);
    const int rv = fset(v);
    if(ru != rv){
        pset[ru] = rv;
        sz[rv] += sz[ru];
        sz[ru] = 0;
    }
}
// Adds to the global `ans` the number of leaves of the compressed group tree `g`
// reachable from u. Here p is u's parent; pass p == u when u is the root.
// A vertex counts as a leaf if it has no children, or if it is the starting
// vertex u and has exactly one child. For dfs(1, 1) this matches the old
// hard-coded `u == 1` test.
// Iterative: the group tree can be a long path, so recursion could overflow the stack.
void dfs(int u, int p){
    vector<pair<int, int> > stk;  // (vertex, its parent)
    stk.push_back(make_pair(u, p));
    while(!stk.empty()){
        int x = stk.back().first;
        int px = stk.back().second;
        stk.pop_back();
        int child = 0;
        for(int y : g[x]){
            if(y == px) continue;
            child++;
            stk.push_back(make_pair(y, x));
        }
        if(child == 0 || (child == 1 && x == u)) ans++;
    }
}
int main(){
n = read(), k = read();
for(int i = 1; i < n; i++){
u = read(), v = read();
edge[u].pb(v);
edge[v].pb(u);
}
for(int i = 1; i <= n; i++){
s[i] = read();
tp[s[i]].pb(i);
}
par[0][1] = 1;
level[1] = 1;
predfs(1, 1);
for(int i = 1; i <= 19; i++){
for(int j = 1; j <= n; j++){
par[i][j] = par[i - 1][par[i - 1][j]];
}
}
for(int i = 1; i <= k; i++){
pset[i] = i;
sz[i] = 1;
if(tp[i].empty()) continue;
int cur = tp[i][0];
for(int j = 1; j < (int)tp[i].size(); j++){
cur = lca(cur, tp[i][j]);
}
sub[cur]++;
}
queue < pair <int, int> > q;
for(int x : leaf) q.push(mp(level[x], x));
while(!q.empty()){
int u = q.front().se;
q.pop();
int p = fset(s[u]);
sz[p] -= sub[u];
deg[par[0][u]]--;
if(!deg[par[0][u]]){
q.push(mp(level[par[0][u]], par[0][u]));
}
if(sz[fset(p)] == 0) continue;
unionset(p, s[par[0][u]]);
}
for(int i = 1; i <= k; i++){
if(pset[i] == i) ret[i] = ++cnt;
}
for(int i = 1; i <= k; i++){
if(pset[i] != i) ret[i] = ret[fset(i)];
}
for(int i = 1; i <= n; i++){
int id = ret[s[i]];
for(int j : edge[i]){
int jd = ret[s[j]];
if(id == jd) continue;
if(ed[mp(id, jd)] > 0) continue;
g[id].pb(jd);
g[jd].pb(id);
ed[mp(id, jd)]++;
ed[mp(jd, id)]++;
}
}
dfs(1, 1);
if(ans == 1) printf("0");
else printf("%d", (ans + 1) / 2);
}
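/*
Grader results:
| # | Result    | Execution time | Memory   | Grader output        |
|---|-----------|----------------|----------|----------------------|
| 1 | Correct   | 35 ms          | 35832 KB | Output is correct    |
| 2 | Incorrect | 37 ms          | 35836 KB | Output isn't correct |
| 3 | Halted    | 0 ms           | 0 KB     | -                    |
*/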
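/*
Grader results:
| # | Result    | Execution time | Memory   | Grader output        |
|---|-----------|----------------|----------|----------------------|
| 1 | Correct   | 35 ms          | 35832 KB | Output is correct    |
| 2 | Incorrect | 37 ms          | 35836 KB | Output isn't correct |
| 3 | Halted    | 0 ms           | 0 KB     | -                    |
*/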
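/*
Grader results:
| # | Result    | Execution time | Memory   | Grader output        |
|---|-----------|----------------|----------|----------------------|
| 1 | Correct   | 35 ms          | 35832 KB | Output is correct    |
| 2 | Incorrect | 37 ms          | 35836 KB | Output isn't correct |
| 3 | Halted    | 0 ms           | 0 KB     | -                    |
*/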
/*
Grader results:
| # | Result    | Execution time | Memory   | Grader output        |
|---|-----------|----------------|----------|----------------------|
| 1 | Correct   | 299 ms         | 49952 KB | Output is correct    |
| 2 | Correct   | 525 ms         | 69604 KB | Output is correct    |
| 3 | Incorrect | 42 ms          | 36728 KB | Output isn't correct |
| 4 | Halted    | 0 ms           | 0 KB     | -                    |
*/
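/*
Grader results:
| # | Result    | Execution time | Memory   | Grader output        |
|---|-----------|----------------|----------|----------------------|
| 1 | Correct   | 35 ms          | 35832 KB | Output is correct    |
| 2 | Incorrect | 37 ms          | 35836 KB | Output isn't correct |
| 3 | Halted    | 0 ms           | 0 KB     | -                    |
*/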