#include <bits/stdc++.h>
typedef long long ll;
using namespace std;
// Tree adjacency lists (node ids start at 1) and, per colour, the list of
// nodes carrying that colour (filled in main).
vector<int> adj[500500];
vector<int> joi[500500];

// col[v]: colour of node v.  cnt[c]: how many nodes have colour c.
// lct[c]: LCA of all nodes of colour c (computed in main).
int col[500500];
int cnt[500500];
int lct[500500];

// Counters written by dfs2 and combined in main to produce the answer.
int val = 0;
int val2 = 0;

// visited / visited2: DFS marks used by dfs1 / dfs2 respectively.
// chk[v]: marked by dfs1; intended to mean no colour has nodes both inside
// and outside the subtree of v.
bool visited[500500];
bool visited2[500500];
bool chk[500500];
// Binary-lifting lowest-common-ancestor table over the global adjacency
// lists `adj`.
//   spa[v][i] : the 2^i-th ancestor of v, or 0 when it does not exist
//   dep[v]    : depth of v; the start node of dfs gets dep[pa] + 1, and
//               node 0 is the "no parent" sentinel with depth 0
struct LCA{
    int spa[500500][21], dep[500500];

    // Fills spa[][] and dep[] for every node reachable from s, treating pa as
    // the parent of s (0 = none). Written with an explicit stack instead of
    // recursion: the per-node arrays allow ~5*10^5 nodes, and recursing down a
    // path-shaped tree that deep can exhaust the call stack. A node is
    // finalised when popped, before its children are pushed, so every
    // ancestor row it reads has already been filled.
    void dfs(int s, int pa=0){
        vector<pair<int, int>> stk;  // (node, parent) pairs still to process
        stk.push_back(make_pair(s, pa));
        while (!stk.empty()){
            int u = stk.back().first, p = stk.back().second;
            stk.pop_back();
            spa[u][0] = p;
            // Stop as soon as the ancestor chain runs past the top (entry 0).
            for (int i=1;i<21 && spa[u][i-1];i++) spa[u][i] = spa[spa[u][i-1]][i-1];
            dep[u] = dep[p]+1;
            for (auto &v:adj[u]) if (v!=p) stk.push_back(make_pair(v, u));
        }
    }

    // Returns the lowest common ancestor of v and u (both must have been
    // reached by dfs).
    int query(int v, int u){
        // Make u the deeper node, then lift it to v's depth.
        if (dep[v]>dep[u]) swap(v, u);
        for (int i=20;i>=0;i--) if (spa[u][i] && dep[spa[u][i]]>dep[v]) u = spa[u][i];
        if (dep[v]!=dep[u]) u = spa[u][0];
        if (v!=u){
            // Lift both to just below their LCA, then take one final step.
            for (int i=20;i>=0;i--) if (spa[v][i] && spa[v][i]!=spa[u][i]) v = spa[v][i], u = spa[u][i];
            v = spa[v][0], u = spa[u][0];
        }
        return v;
    }
}lca;
// Scratch storage for dfs1, indexed by node id (same capacity as the other
// per-node arrays).
static int dfs1LcaWeight[500500];  // sum of cnt[c] over colours c with lct[c] == node
static int dfs1Inside[500500];     // nodes in the subtree whose colour lies wholly inside it
static int dfs1Size[500500];       // subtree size

// Walks the not-yet-visited part of the tree starting at s and returns, for
// the subtree of s,
//   {number of nodes whose colour lies entirely inside the subtree, subtree size}.
// Sets visited[] on every node reached, and sets chk[v] = 1 whenever those two
// numbers are equal for v's subtree, i.e. no colour has nodes both inside and
// outside it.
//
// A colour c lies entirely inside v's subtree exactly when its LCA lct[c]
// does, so the first number is the subtree sum of dfs1LcaWeight. Several
// colours -- not only col[v] -- may share v as their LCA (e.g. a colour whose
// nodes sit in two different child subtrees of v), which is why the weight is
// accumulated per node instead of looked up from col[v].
//
// Uses an explicit stack so a very deep (path-shaped) tree cannot overflow the
// call stack.
pair<int, int> dfs1(int s){
    // Bucket every colour's node count at its LCA. Unused colour slots have
    // cnt == 0; anything whose lct was never set lands on the unused index 0.
    fill(dfs1LcaWeight, dfs1LcaWeight+500500, 0);
    for (int c=0;c<500500;c++) dfs1LcaWeight[lct[c]] += cnt[c];

    // Pre-order list of (node, parent): every node precedes its children.
    vector<pair<int, int>> order, stk;
    stk.push_back(make_pair(s, 0));
    visited[s] = 1;
    while (!stk.empty()){
        pair<int, int> cur = stk.back();
        stk.pop_back();
        order.push_back(cur);
        for (auto &v:adj[cur.first]) if (!visited[v]){
            visited[v] = 1;
            stk.push_back(make_pair(v, cur.first));
        }
    }

    for (auto &e:order) dfs1Inside[e.first] = dfs1Size[e.first] = 0;
    // Reverse pre-order: each child is folded into its parent before the
    // parent itself is examined.
    for (int i=(int)order.size()-1;i>=0;i--){
        int u = order[i].first;
        dfs1Inside[u] += dfs1LcaWeight[u];
        dfs1Size[u] += 1;
        if (dfs1Inside[u]==dfs1Size[u]) chk[u]=1;
        if (i>0){  // order[0] is s, whose parent entry is only a placeholder
            int p = order[i].second;
            dfs1Inside[p] += dfs1Inside[u];
            dfs1Size[p] += dfs1Size[u];
        }
    }
    return make_pair(dfs1Inside[s], dfs1Size[s]);
}
int dfs2(int s, bool test){
int ret=0;
if (chk[s]) ret++;
if (chk[s] && s!=1 && !test) test=1, val2++;
visited2[s] = 1;
for (auto &v:adj[s]) if (!visited2[v]){
ret += dfs2(v, test);
}
if (chk[s] && ret==1 && s!=1) val++;
return ret;
}
// Reads the tree (n nodes, k colours), computes each colour's LCA, marks
// colour-free cut edges (dfs1), counts the leaf components they bound (dfs2)
// and prints ceil(leaf count / 2). Exits with status 1 on malformed input.
int main(){
    int n, k;
    if (scanf("%d %d", &n, &k) != 2) return 1;
    for (int i=0;i<n-1;i++){
        int a, b;
        if (scanf("%d %d", &a, &b) != 2) return 1;
        adj[a].push_back(b);
        adj[b].push_back(a);
    }
    for (int i=1;i<=n;i++){
        if (scanf("%d", col+i) != 1) return 1;
        cnt[col[i]]++;
        joi[col[i]].push_back(i);
    }
    lca.dfs(1);
    // lct[c] = LCA of all nodes of colour c. A colour with no nodes keeps
    // lct[c] == 0 instead of reading joi[c][0] out of bounds.
    for (int i=1;i<=k;i++){
        if (joi[i].empty()) continue;
        lct[i] = joi[i][0];
        for (size_t j=1;j<joi[i].size();j++) lct[i] = lca.query(lct[i], joi[i][j]);
    }
    dfs1(1);
    dfs2(1, 0);
    // val counts chk nodes with no chk node below them; the component
    // containing node 1 is counted too when exactly one top-most chk node
    // (val2 == 1) hangs off it.
    if (val2==1) val++;
    printf("%d\n", (val+1)/2);
    return 0;
}
Compilation message
mergers.cpp: In function 'int main()':
mergers.cpp:60:10: warning: ignoring return value of 'int scanf(const char*, ...)', declared with attribute warn_unused_result [-Wunused-result]
60 | scanf("%d %d", &n, &k);
| ~~~~~^~~~~~~~~~~~~~~~~
mergers.cpp:63:14: warning: ignoring return value of 'int scanf(const char*, ...)', declared with attribute warn_unused_result [-Wunused-result]
63 | scanf("%d %d", &a, &b);
| ~~~~~^~~~~~~~~~~~~~~~~
mergers.cpp:68:14: warning: ignoring return value of 'int scanf(const char*, ...)', declared with attribute warn_unused_result [-Wunused-result]
68 | scanf("%d", col+i);
| ~~~~~^~~~~~~~~~~~~
| # | 결과 (Result) | 실행 시간 (Run time) | 메모리 (Memory) | Grader output |
|---|---|---|---|---|
| 1 | Correct | 16 ms | 23756 KB | Output is correct |
| 2 | Incorrect | 16 ms | 23760 KB | Output isn't correct |
| 3 | Halted | 0 ms | 0 KB | - |
| # | 결과 (Result) | 실행 시간 (Run time) | 메모리 (Memory) | Grader output |
|---|---|---|---|---|
| 1 | Correct | 16 ms | 23756 KB | Output is correct |
| 2 | Incorrect | 16 ms | 23760 KB | Output isn't correct |
| 3 | Halted | 0 ms | 0 KB | - |
| # | 결과 (Result) | 실행 시간 (Run time) | 메모리 (Memory) | Grader output |
|---|---|---|---|---|
| 1 | Correct | 16 ms | 23756 KB | Output is correct |
| 2 | Incorrect | 16 ms | 23760 KB | Output isn't correct |
| 3 | Halted | 0 ms | 0 KB | - |
| # | 결과 (Result) | 실행 시간 (Run time) | 메모리 (Memory) | Grader output |
|---|---|---|---|---|
| 1 | Correct | 93 ms | 37628 KB | Output is correct |
| 2 | Correct | 96 ms | 41116 KB | Output is correct |
| 3 | Incorrect | 21 ms | 24268 KB | Output isn't correct |
| 4 | Halted | 0 ms | 0 KB | - |
| # | 결과 (Result) | 실행 시간 (Run time) | 메모리 (Memory) | Grader output |
|---|---|---|---|---|
| 1 | Correct | 16 ms | 23756 KB | Output is correct |
| 2 | Incorrect | 16 ms | 23760 KB | Output isn't correct |
| 3 | Halted | 0 ms | 0 KB | - |