#include<bits/stdc++.h>
using namespace std;
// Optional compiler tuning, disabled by default:
//#pragma GCC optimize("O3,unroll-loops")
//#pragma GCC target("avx2")
//#pragma GCC target("sse4")

// Short-hand macros used by the solution.
#define all(v) v.begin(),v.end()
#define F first
#define S second
#define pub push_back
#define pob pop_back
#define puf push_front
#define pof pop_front
//#define randi uniform_int_distribution<long long>
// Drops consecutive duplicates from a container in place (sort first for a full dedupe).
#define damoon(v) v.resize(unique(all(v))-v.begin())
//mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());
//randi dist(0,10000000000000000);

// Type aliases.
using pii = pair<int,int>;
using pll = pair<long long,long long>;
using pib = pair<int,bool>;
using plb = pair<long long,bool>;
using pip = pair<int,pii>;
using ppi = pair<pii,int>;
using veci = vector<int>;
using vecl = vector<long long>;
using vecb = vector<bool>;
using vecp = vector<pii>;
using seti = set<int>;
using setl = set<long long>;
using setp = set<pii>;
using mapii = map<int,int>;
using mapll = map<long long,long long>;
using mapib = map<int,bool>;
using maplb = map<long long,bool>;

// Constants.
constexpr int inf=1e9;
constexpr int mod=1e9+7;
constexpr int neginf=-1e9;
constexpr int N=2e5+5; // capacity of every per-vertex / per-color array below
const double PI=acos(-1);

// Global state (zero-initialised at program start).
int n,k;        // vertex count and color count read by solve()
int ans;        // smallest distinct-color count found by Solve()
int cnt;        // stamp of the component Solve() is currently processing
int col[N];     // color of each vertex, stored 0-based
int nxn[N];     // next vertex with the same color (the colors form cycles)
int see[N];     // set to 1 once a vertex has been removed as a centroid
int mark[N];    // cnt / -cnt for vertices of the current component (negative = already taken)
int par[N];     // parent pointer toward the current centroid
int siz[N];     // subtree sizes written by dfs1
int prt[N];     // scratch "color already counted" flags
veci adj[N];    // adjacency lists of the tree
veci vec[N];    // vertices of each color
void dfs1(int u,int p,veci& t)
{
t.pub(u);
siz[u]=1;
for(int v:adj[u])
{
if(v==p or see[v])
continue;
dfs1(v,u,t);
siz[u]+=siz[v];
}
}
// Stamps every vertex of the current component with mark[v]=cnt.
// The component is everything reachable from u without passing through p or a
// vertex with see[v]!=0.
// Also records par[v] = the neighbor one step closer to u.
// Iterative, with an explicit stack, so a path-shaped tree cannot overflow the
// call stack.
void dfs2(int u,int p)
{
    std::vector<std::pair<int,int>> stk; // pending (vertex, parent) pairs
    stk.push_back(std::make_pair(u,p));
    while(!stk.empty())
    {
        const int x=stk.back().first;
        const int px=stk.back().second;
        stk.pop_back();
        mark[x]=cnt;
        for(int y:adj[x])
        {
            if(y==px or see[y])
                continue;
            par[y]=x;
            stk.push_back(std::make_pair(y,x));
        }
    }
}
// Processes the component of not-yet-removed vertices that contains x.
// 1. Finds the component's centroid rot.
// 2. Grows the smallest vertex set that contains rot and is closed under two
//    rules: "add the path to rot" and "add every vertex of the same color".
// 3. If that set never needs a vertex outside the component, updates ans with
//    the number of distinct colors in it.
// 4. Removes rot and recurses into each remaining piece.
void Solve(int x)
{
    cnt++; // fresh stamp: mark[v]==cnt or -cnt identifies this component's vertices
    veci t;
    dfs1(x,0,t); // t = component in preorder, siz[] = subtree sizes rooted at x
    const int tot=t.size();
    // Centroid = deepest vertex whose subtree holds more than half the component.
    // Such vertices form a chain starting at x, and preorder lists deeper ones
    // later, so the last match is the centroid.
    int rot=x;
    for(int u:t)
        if(siz[u]>tot/2)
            rot=u;
    par[rot]=rot;
    dfs2(rot,rot); // mark[v]=cnt on the component; par[] now points toward rot
    // Closure search. mark[v]==-cnt means v is already in q.
    veci q;
    q.push_back(rot);
    mark[rot]=-cnt;
    bool escapes=false; // some required color also occurs outside the component
    for(size_t i=0;i<q.size();i++)
    {
        const int u=q[i];
        // Take the path toward rot, stopping at the first vertex already taken.
        // That vertex's own path is taken when it is processed.
        for(int v=u;mark[par[v]]==cnt;)
        {
            v=par[v];
            mark[v]=-cnt;
            q.push_back(v);
        }
        // Follow the color cycle: u's color needs every vertex of that color.
        const int w=nxn[u];
        if(std::abs(mark[w])!=cnt)
        {
            escapes=true; // w lies outside this component: no valid set here
            break;
        }
        if(mark[w]==cnt)
        {
            mark[w]=-cnt;
            q.push_back(w);
        }
    }
    if(!escapes)
    {
        // Count the distinct colors in q.
        // prt[] is a scratch table and is cleared again right after.
        int ret=0;
        for(int u:q)
            if(!prt[col[u]])
            {
                prt[col[u]]=1;
                ret++;
            }
        for(int u:q)
            prt[col[u]]=0;
        ans=std::min(ans,ret);
    }
    see[rot]=1; // remove the centroid and decompose the remaining pieces
    for(int v:adj[rot])
        if(!see[v])
            Solve(v);
}
void solve()
{
cin>>n>>k;
for(int i=1;i<n;i++)
{
int u,v;
cin>>u>>v;
adj[u].pub(v);
adj[v].pub(u);
}
for(int i=1;i<=n;i++)
{
cin>>col[i];
col[i]--;
if(col[i]<k)
vec[col[i]].pub(i);
}
for(int i=0;i<k;i++)
{
if(vec[i].empty())
continue;
int m=vec[i].size();
for(int j=0;j<m;j++)
{
int u=vec[i][j];
int nxi=(j+1)%m;
nxn[u]=vec[i][nxi];
}
}
ans=n;
cnt=0;
for(int i=1;i<=n;i++)
see[i]=mark[i]=par[i]=siz[i]=0;
for(int i=0;i<=k;i++)
prt[i]=0;
Solve(1);
cout<<ans-1<<"\n";
}
int main()
{
    // Decouple the C++ streams from C stdio and from each other for faster bulk I/O.
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);
    //ifstream fin("in.txt");
    //ofstream fout("out.txt");
    int tests=1;
    //cin>>tests; // uncomment for multi-test input
    for(;tests>0;--tests)
        solve();
}
/*
Judge submission-page residue that was pasted after the program. It is not C++
and is kept inside this comment only so that the file still compiles:
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
*/