#include <iostream>
#include <bits/stdc++.h>
using namespace std;
// Short-hands used throughout this file.
// 64-bit signed integer alias (not referenced by the code shown here).
using llo = int64_t;
// WARNING: `a` and `b` are object-like macros. Every later token `a` / `b` in
// this translation unit is rewritten to `first` / `second`, so they must never
// be used as variable names below (e.g. `p.b` means `p.second`).
#define a first
#define b second
// Function-name short-hands.
#define mp make_pair
#define pb push_back
// ---- Tree storage (0-indexed nodes; main converts the 1-indexed input) ----
vector<int> adj[500001];   // adjacency lists of the input tree
int col[500001];           // col[v]  = colour/group id read for node v
vector<int> cc[500001];    // cc[c]   = every node whose colour id is c
int par[500001][20];       // par[v][j] = 2^j-th ancestor of v (filled by dfs for j=0, by main for j>0)
int lev[500001];           // lev[v]  = depth of v below the start node of dfs
int st[500001];            // st[v]   = pre-order index of v
int endd[500001];          // endd[v] = last pre-order index inside v's subtree
int co=0;                  // next pre-order index to hand out

/*
 * Depth-first traversal of the tree in adj[], starting at `no`.
 * For every reached node v it records:
 *   st[v]     pre-order index (taken from the global counter `co`)
 *   endd[v]   last pre-order index in v's subtree, so u is in v's subtree
 *             iff st[v] <= st[u] <= endd[v]
 *   par[v][0] parent of v (for `no` itself: par2)
 *   lev[v]    depth (for `no` itself: lev2)
 * Neighbours equal to a node's parent are not entered; children are visited in
 * adjacency-list order.
 * Returns endd[no]. A value is always returned, because flowing off the end of an
 * int function is undefined behaviour.
 * An explicit stack is used instead of recursion, so path-shaped trees with
 * ~5e5 nodes cannot overflow the call stack.
 */
int dfs(int no,int par2=0,int lev2=0){
    // Each entry is (node, index of the next neighbour of that node to examine).
    vector<pair<int,size_t>> stk;
    st[no]=co++;
    par[no][0]=par2;
    lev[no]=lev2;
    stk.push_back(make_pair(no,(size_t)0));
    while(!stk.empty()){
        const int u=stk.back().first;
        const size_t nxt=stk.back().second;
        if(nxt<adj[u].size()){
            stk.back().second=nxt+1;
            const int v=adj[u][nxt];
            if(v==par[u][0]){
                continue;   // do not walk back up to the parent
            }
            st[v]=co++;
            par[v][0]=u;
            lev[v]=lev[u]+1;
            stk.push_back(make_pair(v,(size_t)0));
        }else{
            // All children of u are finished: its subtree ends at the last index handed out.
            endd[u]=co-1;
            stk.pop_back();
        }
    }
    return endd[no];
}
/*
 * Lowest common ancestor of nodes aa and bb using the binary-lifting table.
 * Expects par[v][j] to hold the 2^j-th ancestor of v (jumps past the root land
 * on node 0, the root main passes to dfs) and lev[v] to hold v's depth.
 * Jumps are taken from the largest power (2^19) down to 2^0.
 */
int lca(int aa,int bb){
    // Make bb the deeper (or equally deep) node.
    if(lev[bb]<lev[aa]){
        swap(aa,bb);
    }
    // Raise bb to aa's depth: one jump per set bit of the depth gap.
    const int gap=lev[bb]-lev[aa];
    for(int k=19;k>=0;--k){
        if((gap>>k)&1){
            bb=par[bb][k];
        }
    }
    if(bb==aa){
        return aa;   // one node was an ancestor of the other
    }
    // Raise both nodes while their 2^k-th ancestors still differ; afterwards
    // they are distinct children of the LCA.
    for(int k=19;k>=0;--k){
        const int upX=par[aa][k];
        const int upY=par[bb][k];
        if(upX==upY){
            continue;
        }
        aa=upX;
        bb=upY;
    }
    return par[aa][0];
}
// Disjoint-set links: par3[v] points towards v's representative.
// main initialises par3[i] = i; dfs2 links a node to its tree parent.
int par3[500001];

/*
 * Returns the representative (root) of the set containing `no`. Every node on
 * the walked path is re-pointed straight at that root (full path compression).
 * This is iterative because parent chains can be as long as the tree, and a
 * recursive version could then recurse ~n deep and overflow the call stack.
 */
int find(int no){
    int root=no;
    while(par3[root]!=root){
        root=par3[root];
    }
    // Second pass: compress the path from `no` up to `root`.
    while(par3[no]!=root){
        const int up=par3[no];
        par3[no]=root;
        no=up;
    }
    return root;
}
vector<int> proc[500001];
set<pair<int,int>> cur;
int dfs2(int no,int par2=0){
if(cur.size()>0){
auto j=cur.lower_bound({st[no],0});
if(j==cur.end()){
}
else if((*j).b<=endd[no]){
par3[no]=par2;
}
}
for(auto j:proc[no]){
cur.insert({st[j],endd[j]});
}
auto tt=cur.find({st[no],endd[no]});
if(tt!=cur.end()){
cur.erase(tt);
}
for(auto j:adj[no]){
if(j==par2){
continue;
}
dfs2(j,no);
}
}
set<int> adj2[500001];   // adjacency of the contracted tree, indexed by DSU representative
int ans=0;               // number of leaves (degree-1 vertices) of the contracted tree

/*
 * Steps:
 *   1. Read a tree with n nodes and each node's colour id (1..k).
 *   2. Contract every edge that some colour straddles (dfs2 + DSU).
 *   3. Print (L + 1) / 2, i.e. ceil(L / 2), where L is the number of leaves of
 *      the contracted tree. This is 0 when everything collapses into one vertex.
 */
int main(){
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    int n,k;
    cin>>n>>k;
    int aa,bb;
    co=0;
    cur.clear();
    // Tree edges: input is 1-indexed, storage is 0-indexed.
    for(int i=0;i<n-1;i++){
        cin>>aa>>bb;
        adj[aa-1].push_back(bb-1);
        adj[bb-1].push_back(aa-1);
    }
    // Colour of each node, and the nodes of each colour.
    for(int i=0;i<n;i++){
        cin>>aa;
        col[i]=aa;
        cc[aa].push_back(i);
    }
    dfs(0);
    for(int i=0;i<n;i++){
        par3[i]=i;
    }
    // Binary lifting. par[v][j] reads par[w][j-1] for an ancestor w that may
    // have a larger index than v, so level j-1 must be complete for ALL nodes
    // before level j is built. The level loop must therefore be the outer one
    // (with the node loop outside, unfilled zeros were read and LCAs came out wrong).
    for(int j=1;j<20;j++){
        for(int i=0;i<n;i++){
            par[i][j]=par[par[i][j-1]][j-1];
        }
    }
    // For each colour with at least two nodes, hand all its nodes to its LCA.
    for(int i=1;i<=k;i++){
        if(cc[i].size()<=1){
            continue;
        }
        aa=cc[i][0];
        for(size_t j=1;j<cc[i].size();j++){
            aa=lca(aa,cc[i][j]);
        }
        for(int v:cc[i]){
            proc[aa].push_back(v);
        }
    }
    dfs2(0);
    // Contracted tree: one vertex per DSU set, an edge wherever a tree edge
    // joins two different sets.
    for(int i=0;i<n;i++){
        const int ri=find(i);
        for(int j:adj[i]){
            const int rj=find(j);
            if(ri!=rj){
                adj2[rj].insert(ri);
                adj2[ri].insert(rj);
            }
        }
    }
    for(int i=0;i<n;i++){
        if(adj2[i].size()==1){
            ans+=1;
        }
    }
    cout<<(ans+1)/2<<'\n';
    return 0;
}
/*
Compilation message (judge output):
mergers.cpp: In function 'int dfs(int, int, int)':
mergers.cpp:29:1: warning: no return statement in function returning non-void [-Wreturn-type]
 }
 ^
mergers.cpp: In function 'int dfs2(int, int)':
mergers.cpp:85:1: warning: no return statement in function returning non-void [-Wreturn-type]
 }
 ^
mergers.cpp: In function 'int main()':
mergers.cpp:127:16: warning: comparison between signed and unsigned integer expressions [-Wsign-compare]
   for(int j=1;j<cc[i].size();j++){
               ~^~~~~~~~~~~~~
*/
/*
Judge results:
| # | Verdict   | Execution time | Memory   | Grader output        |
|---|-----------|----------------|----------|----------------------|
| 1 | Correct   | 42 ms          | 59128 KB | Output is correct    |
| 2 | Incorrect | 40 ms          | 59128 KB | Output isn't correct |
| 3 | Halted    | 0 ms           | 0 KB     | -                    |
*/
/*
Judge results:
| # | Verdict   | Execution time | Memory   | Grader output        |
|---|-----------|----------------|----------|----------------------|
| 1 | Correct   | 42 ms          | 59128 KB | Output is correct    |
| 2 | Incorrect | 40 ms          | 59128 KB | Output isn't correct |
| 3 | Halted    | 0 ms           | 0 KB     | -                    |
*/
/*
Judge results:
| # | Verdict   | Execution time | Memory   | Grader output        |
|---|-----------|----------------|----------|----------------------|
| 1 | Correct   | 42 ms          | 59128 KB | Output is correct    |
| 2 | Incorrect | 40 ms          | 59128 KB | Output isn't correct |
| 3 | Halted    | 0 ms           | 0 KB     | -                    |
*/
/*
Judge results:
| # | Verdict   | Execution time | Memory   | Grader output        |
|---|-----------|----------------|----------|----------------------|
| 1 | Correct   | 182 ms         | 77952 KB | Output is correct    |
| 2 | Correct   | 222 ms         | 84960 KB | Output is correct    |
| 3 | Incorrect | 41 ms          | 59640 KB | Output isn't correct |
| 4 | Halted    | 0 ms           | 0 KB     | -                    |
*/
/*
Judge results:
| # | Verdict   | Execution time | Memory   | Grader output        |
|---|-----------|----------------|----------|----------------------|
| 1 | Correct   | 42 ms          | 59128 KB | Output is correct    |
| 2 | Incorrect | 40 ms          | 59128 KB | Output isn't correct |
| 3 | Halted    | 0 ms           | 0 KB     | -                    |
*/