제출 #236812

# | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리
236812 | MvC | Mergers (JOI19_mergers) | C++11
100 / 100
1103 ms | 127680 KiB
#pragma GCC optimize("O3")
#pragma GCC optimize("unroll-loops")
#include <bits/stdc++.h>
// Competitive-programming template shorthands. In this file only `pb` is
// actually used (plus `ll` inside the constants below); the others are
// template leftovers.
// rc(x): print x followed by endl and return 0 from the enclosing function.
#define rc(x) return cout<<x<<endl,0
#define pb push_back
#define mkp make_pair
#define in insert
#define er erase
#define fd find
#define fr first
#define sc second
#define all(x) x.begin(),x.end()
typedef long long ll;
typedef long double ld;
// INF, llinf, inf and mod are not referenced anywhere in this file.
const ll INF=0x3f3f3f3f3f3f3f3f;
const ll llinf=(1LL<<62);
const int inf=(1<<30);
// Capacity of every per-vertex / per-color array: vertex ids and color ids
// read from input must be strictly below nmax.
const int nmax=5e5+50;
const ll mod=1e9+7;
using namespace std;
// Global state (static storage, so zero-initialized):
//   n, k           - vertex count and color count read by main.
//   x, y, i, j, c  - scratch variables used by main's input/loops.
//   up[v][b]       - 2^b-th ancestor of v, filled by dfs; the root passed to
//                    dfs as its own parent is its own ancestor at every level.
//   tin, tout, tt  - DFS entry/exit timestamps and their shared counter.
//   lvl[v]         - depth of v (the dfs root gets depth 1).
//   lc[v]          - seeded in main with lvl[v] - lvl[LCA of v's color], then
//                    propagated upward by bld; nonzero means the edge from v
//                    to its parent is merged by bld.
//   pr, sz         - disjoint-set parent pointers and set sizes (fnd/uni).
//   deg[r]         - number of unmerged tree edges incident to DSU root r.
//   rs             - number of DSU roots with deg == 1.
int x,y,i,n,up[nmax][20],tin[nmax],tout[nmax],tt,lc[nmax],sz[nmax],pr[nmax],deg[nmax],rs,j,lvl[nmax],c,k;
// a[v]: adjacency list of vertex v; vc[col]: vertices whose color is col.
vector<int>a[nmax],vc[nmax];
void dfs(int x,int p)
{
    tin[x]=++tt;
    up[x][0]=p;
    lvl[x]=lvl[p]+1;
    for(int i=1;i<20;i++)up[x][i]=up[up[x][i-1]][i-1];
    for(int i=0;i<(int)a[x].size();i++)if(a[x][i]!=p)dfs(a[x][i],x);
    tout[x]=++tt;
}
// Returns 1 when x is an ancestor of y (a vertex counts as its own ancestor),
// i.e. when y's Euler interval [tin[y], tout[y]] lies inside x's; 0 otherwise.
int anc(int x,int y)
{
    const bool entersAfter=tin[y]>=tin[x];
    const bool leavesBefore=tout[y]<=tout[x];
    return entersAfter&&leavesBefore;
}
int lca(int x,int y)
{
    if(anc(x,y))return x;
    if(anc(y,x))return y;
    for(int i=19;i>=0;i--)if(!anc(up[x][i],y))x=up[x][i];
    return up[x][0];
}
// Disjoint-set find: returns the representative of x's set and re-points
// every vertex on the walked path directly at that representative.
int fnd(int x)
{
	int root=x;
	while(pr[root]!=root)root=pr[root];

	// Second pass: full path compression.
	while(pr[x]!=root)
	{
		const int nxt=pr[x];
		pr[x]=root;
		x=nxt;
	}
	return root;
}
// Disjoint-set union by size: merges the sets containing x and y, hanging the
// smaller tree under the larger one (x's representative wins on equal sizes).
void uni(int x,int y)
{
	int rx=fnd(x);
	int ry=fnd(y);
	if(rx!=ry)
	{
		if(sz[rx]<sz[ry])swap(rx,ry);
		sz[rx]+=sz[ry];
		pr[ry]=rx;
	}
}
// Post-order pass over the tree rooted at x, whose parent is p (the root is
// passed as its own parent). The caller seeds lc[] beforehand (main sets
// lc[v] = lvl[v] - lvl[LCA of v's color]). When a vertex v finishes, i.e.
// after all of its children have finished:
//   lc[par] = max(lc[par], lc[v] - 1)   -- pushes v's remaining height up one level,
//   if lc[v] != 0, v is merged into par's disjoint set.
// Children are processed in adjacency-list order, exactly like the recursive
// formulation. An explicit stack replaces recursion so that a path-shaped tree
// of ~nmax vertices cannot exhaust the call stack.
void bld(int x,int p)
{
	// One pending vertex: the vertex, its parent, and the next adjacency index.
	struct Frame { int v, par, next; };
	std::vector<Frame> st(1,Frame{x,p,0});
	while(!st.empty())
	{
		Frame &f=st.back();
		if(f.next<(int)a[f.v].size())
		{
			int y=a[f.v][f.next++];
			// The new Frame is built before push_back, so reading f here is safe.
			if(y!=f.par)st.push_back(Frame{y,f.v,0});
		}
		else
		{
			// Copy out before pop_back invalidates f.
			const int v=f.v,par=f.par;
			st.pop_back();
			lc[par]=max(lc[par],lc[v]-1);
			if(lc[v])uni(par,v);
		}
	}
}
int main()
{
	//freopen("sol.in","r",stdin);
	//freopen("sol.out","w",stdout);
	//mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());
	ios_base::sync_with_stdio(false);cin.tie(0);cerr.tie(0);cout.tie(0);
	cin>>n>>k;
	for(i=1;i<n;i++)
	{
		cin>>x>>y;
		a[x].pb(y);
		a[y].pb(x);
	}
	dfs(1,1);
	for(i=1;i<=n;i++)
	{
		cin>>c;
		pr[i]=i;
		sz[i]=1;
		vc[c].pb(i);
	}
	for(i=1;i<=k;i++)
	{
		if(vc[i].empty())continue;
		x=vc[i][0];
		for(j=1;j<(int)vc[i].size();j++)x=lca(x,vc[i][j]);
		for(j=0;j<(int)vc[i].size();j++)lc[vc[i][j]]=lvl[vc[i][j]]-lvl[x];
	}
	bld(1,1);
	for(i=2;i<=n;i++)
	{
		if(lc[i])continue;
		deg[fnd(i)]++;
		deg[fnd(up[i][0])]++;
	}
	for(i=1;i<=n;i++)if(deg[i]==1)rs++;
	cout<<(rs+1)/2<<endl;
	return 0;
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...