#include<iostream>
#include<stack>
#include<map>
#include<vector>
#include<string>
#include<unordered_map>
#include <queue>
#include<cstring>
#include<limits.h>
#include<cmath>
#include<set>
#include<algorithm>
#include<bitset>
using namespace std;
// Shorthand macros used throughout the file.
// NOTE: `f`, `s` and `p` are plain token replacements: every later identifier
// spelled exactly `f`, `s` or `p` (variables, parameters, members) is rewritten
// to `first`, `second` or `push` by the preprocessor.
#define ll long long
#define f first
// Every later `endl` token becomes the string literal "\n" (a newline without a flush).
#define endl "\n"
#define s second
#define pii pair<int,int>
#define ppii pair<pii,int>
#define pb push_back
// Expands to the begin/end iterator pair of container x.
#define all(x) x.begin(),x.end()
// Loop header counting i from 0 to n-1; declares its own `i`.
#define F(n) for(int i=0;i<n;i++)
#define lb lower_bound
#define p push
// Unsyncs the C++ streams from C stdio and unties cin from cout.
#define fastio ios::sync_with_stdio(false);cin.tie(NULL);
using namespace std;
// mxn: bound used to size the global arrays (5*1e5 is a double expression converted to int).
// mod: not referenced anywhere in the visible code.
// lg:  highest level of the ancestor table used by lca().
const int mxn=5*1e5,mod=998244353,lg=19;
// Redirects standard input to "<name>.in" and standard output to "<name>.out",
// after switching the C++ streams to unsynchronised / untied mode.
//
// name: base file name, without extension.
//
// freopen() returns a null pointer on failure; that result used to be
// discarded. A failure is now reported on stderr, and execution continues so
// the helper stays best-effort.
void setIO(std::string name) {
    std::ios_base::sync_with_stdio(0);
    std::cin.tie(0);

    const std::string inputPath = name + ".in";
    const std::string outputPath = name + ".out";

    if (freopen(inputPath.c_str(), "r", stdin) == nullptr) {
        std::cerr << "setIO: cannot open " << inputPath << " for reading\n";
    }
    if (freopen(outputPath.c_str(), "w", stdout) == nullptr) {
        std::cerr << "setIO: cannot open " << outputPath << " for writing\n";
    }
}
// File-scope counter; main() declares its own local `ans`, which shadows this one.
int ans=0;
// pa[x], sz[x]: union-find parent link and set size (see find() / merg()).
// up[v][k]:     ancestor table; column 0 is written by solve(), higher columns
//               are derived in main(), and lca() reads it.
// h[v]:         depth assigned by solve() as the parent's depth plus one.
int pa[mxn+10],sz[mxn+10],up[mxn+10][lg+2],h[mxn+10];
// adj[v]: neighbours of vertex v, filled from the edge list read in main().
vector<int>adj[mxn+10];
// Union-find lookup: returns the representative of the set containing u and
// re-points every element visited on the way directly at that representative.
int find(int u){
    // First pass: walk up to the representative.
    int root=u;
    while(pa[root]!=root){
        root=pa[root];
    }
    // Second pass: compress the path that was just walked.
    while(pa[u]!=root){
        const int next=pa[u];
        pa[u]=root;
        u=next;
    }
    return root;
}
void merg(int u,int v){
int a=find(u),b=find(v);
if(a==b)return;
if(sz[a]>sz[b]){
pa[b]=a;
sz[a]+=sz[b];
return;
}
pa[a]=b;
sz[b]+=sz[a];
}
// st[v]: value read for vertex v in main(); used as a union-find element id.
// un[v]: a vertex stored per v - initialised in main() to the LCA of all
//        vertices sharing st[v], then updated by solve2().
int st[mxn+10],un[mxn+10];
// Indexed by an st value; holds the vertices that carry that value.
// NOTE: because of `#define s second`, this array is really named `second`
// after preprocessing.
vector<int>s[mxn+10];
// Depth-first walk starting at `cur`, whose parent is `p`. For every child it
// stores the direct parent in up[child][0] and the depth in h[child], then
// recurses into the child.
void solve(int cur,int p){
    for(const int child:adj[cur]){
        if(child!=p){
            up[child][0]=cur;
            h[child]=h[cur]+1;
            solve(child,cur);
        }
    }
}
// Returns the lowest common ancestor of u and v, using the depths in h[] and
// the ancestor table up[][] (up[x][k] is the 2^k-th ancestor of x).
//
// Steps: lift the deeper vertex until both are at the same depth, then lift
// both together from the highest level down while their ancestors differ; the
// answer is the parent of where they stop.
int lca(int u,int v){
    if(h[u]<h[v])std::swap(u,v);
    const int diff=h[u]-h[v];
    // Only levels 0..lg exist in up[][] (it has lg+2 columns). The previous
    // bound of 30 could index past the end of a row for a large difference.
    for(int bit=0;bit<=lg;bit++){
        if((diff>>bit)&1)u=up[u][bit];
    }
    if(u==v)return u;
    for(int bit=lg;bit>=0;bit--){
        if(up[u][bit]!=up[v][bit]){
            u=up[u][bit];
            v=up[v][bit];
        }
    }
    return up[u][0];
}
// Post-order walk starting at `cur`, whose parent is `p`.
//
// After a child has been processed, a non-zero un[child] is folded into
// un[cur] through lca(), and the st values of `cur` and the child are united
// in the union-find. Finally un[cur] is reset to 0 when it equals `cur`.
//
// un[child]==0 can only come from that reset, so such a child contributes
// nothing across the edge (cur, child).
//
// The unused local copy of un[cur] has been removed (it triggered
// -Wunused-variable).
void solve2(int cur,int p){
    for(const int child:adj[cur]){
        if(child==p)continue;
        solve2(child,cur);
        if(un[child]){
            un[cur]=lca(un[cur],un[child]);
            merg(st[cur],st[child]);
        }
    }
    if(un[cur]==cur)un[cur]=0;
}
// deg[x]: incremented in main() at index st[j], once per adjacency entry (i, j)
// whose endpoints have different st values.
int deg[mxn+10];
int32_t main(){
//setIO("pieaters");
fastio
int n,m;cin>>n>>m;
for(int i=0;i<n-1;i++){
int u,v;cin>>u>>v;
adj[u].pb(v);
adj[v].pb(u);
}
h[0]=INT_MAX;
for(int i=1;i<=n;i++)cin>>st[i],s[st[i]].pb(i);
for(int i=1;i<=m;i++)pa[i]=i,sz[i]=1;
solve(1,-1);
for(int i=1;i<=lg;i++)for(int j=1;j<=n;j++)up[j][i]=up[up[j][i-1]][i-1];
for(int i=1;i<=m;i++){
int node=-1;
for(auto j:s[i]){
if(node==-1)node=j;
else node=lca(node,j);
}
for(auto j:s[i])un[j]=node;
}
solve2(1,-1);
int ans=0;
for(int i=1;i<=m;i++)st[i]=find(st[i]);
for(int i=1;i<=n;i++)for(auto j:adj[i])if(st[i]!=st[j])deg[st[j]]++;
for(int i=1;i<=n;i++)if(deg[i]==1)ans++;
cout<<(ans+1)/2;
}
// Compilation message
// mergers.cpp: In function 'void solve2(int, int)':
// mergers.cpp:71:9: warning: unused variable 'g' [-Wunused-variable]
// 71 | int g=un[cur];
// | ^
// mergers.cpp: In function 'void setIO(std::string)':
// mergers.cpp:31:9: warning: ignoring return value of 'FILE* freopen(const char*, const char*, FILE*)' declared with attribute 'warn_unused_result' [-Wunused-result]
// 31 | freopen((name+".in").c_str(),"r",stdin);
// | ~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// mergers.cpp:32:9: warning: ignoring return value of 'FILE* freopen(const char*, const char*, FILE*)' declared with attribute 'warn_unused_result' [-Wunused-result]
// 32 | freopen((name+".out").c_str(),"w",stdout);
// | ~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// # | Result    | Execution time | Memory   | Grader output
// 1 | Correct   | 11 ms          | 23764 KB | Output is correct
// 2 | Incorrect | 11 ms          | 23864 KB | Output isn't correct
// 3 | Halted    | 0 ms           | 0 KB     | -
// # | Result    | Execution time | Memory   | Grader output
// 1 | Correct   | 11 ms          | 23764 KB | Output is correct
// 2 | Incorrect | 11 ms          | 23864 KB | Output isn't correct
// 3 | Halted    | 0 ms           | 0 KB     | -
// # | Result    | Execution time | Memory   | Grader output
// 1 | Correct   | 11 ms          | 23764 KB | Output is correct
// 2 | Incorrect | 11 ms          | 23864 KB | Output isn't correct
// 3 | Halted    | 0 ms           | 0 KB     | -
// # | Result    | Execution time | Memory   | Grader output
// 1 | Correct   | 52 ms          | 37356 KB | Output is correct
// 2 | Correct   | 63 ms          | 41032 KB | Output is correct
// 3 | Incorrect | 14 ms          | 24276 KB | Output isn't correct
// 4 | Halted    | 0 ms           | 0 KB     | -
// # | Result    | Execution time | Memory   | Grader output
// 1 | Correct   | 11 ms          | 23764 KB | Output is correct
// 2 | Incorrect | 11 ms          | 23864 KB | Output isn't correct
// 3 | Halted    | 0 ms           | 0 KB     | -