This submission was migrated from the previous version of oj.uz, which used a different grading machine. It may receive a different result if resubmitted.
#include <bits/stdc++.h>
// Competitive-programming shorthands used throughout the file.
#define ll long long
#define pb push_back
// NOTE(review): `x` and `y` expand to `first`/`second`, so ANY identifier
// spelled x or y below is rewritten by the preprocessor. Neither is used as
// an identifier in the code visible here.
#define x first
#define y second
#define all(u) u.begin(),u.end()
#define sz(u) (int)(u.size())
#define INF (int)(1e9)
using namespace std;
// Capacity of the per-vertex arrays; vertex ids and 0-based colors must be
// below this (assumes n <= 2e5 -- TODO confirm against the problem limits).
const int MAXN=2e5+5;
// n = number of vertices, k = read in main but otherwise unused,
// ans = smallest color count found by solve (main prints ans-1).
int n,k,ans=INF;
// tr: tree adjacency lists; grp[c]: vertices whose (0-based) color is c;
// cd: adjacency of the centroid tree built by decomp.
vector<vector<int>> tr(MAXN),grp(MAXN),cd(MAXN);
// col: vertex color; subsz: component subtree sizes (calc_sz);
// parcd/dpthcd: parent/depth in the centroid tree; par/dpth/in: parent,
// depth and first Euler-tour index in the tree rooted at vertex 0 (makeul).
int col[MAXN],subsz[MAXN],parcd[MAXN],par[MAXN],dpth[MAXN],in[MAXN],dpthcd[MAXN];
// blocked[v]: v has already been chosen as a centroid.
bool blocked[MAXN];
// Sparse table over `eul`: lca[i][j] is the minimum-depth vertex among
// eul[i .. i+2^j-1]. The Euler tour pushes a vertex on entry and again after
// each child, i.e. 2n-1 entries, so it needs ~2*MAXN rows (MAXN rows
// overflowed for large n). 20 levels cover every length below 2^20.
int lca[2*MAXN][20];
// Root of the centroid tree (-1 until decomp runs).
int root=-1;
// Euler tour of the tree rooted at vertex 0.
vector<int> eul;
// Fills subsz[] for the component reachable from `u` without stepping onto
// `pr` or any blocked vertex; subsz[v] is the size of v's subtree when the
// component is rooted at `u`.
// Done iteratively: gather vertices in BFS order with their parents, then
// walk that order backwards so every child is folded into its parent.
void calc_sz(int u, int pr)
{
    std::vector<std::pair<int,int>> order;  // (vertex, parent) in BFS order
    order.emplace_back(u,pr);
    for(std::size_t i=0;i<order.size();++i)
    {
        const int node=order[i].first;
        const int from=order[i].second;
        subsz[node]=1;
        for(int nb:tr[node])
            if(nb!=from && !blocked[nb]) order.emplace_back(nb,node);
    }
    // Index 0 is `u` itself, whose parent is outside the component.
    for(std::size_t i=order.size();i-- > 1;)
        subsz[order[i].second]+=subsz[order[i].first];
}
// Walks down from `u` (arriving from `pr`), always stepping into the first
// non-blocked neighbour whose subsz exceeds `tar`, and returns the vertex
// where no such neighbour remains. Expects subsz[] from calc_sz on the same
// component; decomp passes tar = half the component size.
int find_centroid(int u, int pr, int tar)
{
    int cur=u;
    int from=pr;
    while(true)
    {
        int nxt=-1;
        for(int nb:tr[cur])
        {
            if(nb!=from && !blocked[nb] && subsz[nb]>tar)
            {
                nxt=nb;
                break;
            }
        }
        if(nxt==-1) return cur;
        from=cur;
        cur=nxt;
    }
}
void decomp(int u, int pr, int d=0)
{
calc_sz(u,pr);
int centroid=find_centroid(u,pr,subsz[u]/2);
blocked[centroid]=1;
parcd[centroid]=pr;
dpthcd[centroid]=d;
if(pr==-1) root=centroid;
else
{
cd[centroid].pb(pr);
cd[pr].pb(centroid);
}
for(int v:tr[centroid])
{
if(blocked[v]) continue;
decomp(v,centroid,d+1);
}
}
// Euler tour of the tree from `u` (parent `pr`, depth `d`): a vertex is
// appended when first entered and again after each of its children finishes.
// Records in[] (first position in eul), dpth[] and par[].
// Uses an explicit stack instead of recursion; the output order is the same.
void makeul(int u, int pr, int d)
{
    in[u]=sz(eul);
    eul.pb(u);
    dpth[u]=d;
    par[u]=pr;

    // (vertex, index of the next neighbour to examine)
    std::vector<std::pair<int,int>> stk;
    stk.emplace_back(u,0);
    while(!stk.empty())
    {
        const int cur=stk.back().first;
        const int idx=stk.back().second;
        if(idx==sz(tr[cur]))
        {
            // cur is finished: return to its parent and re-append it.
            stk.pop_back();
            if(!stk.empty()) eul.pb(stk.back().first);
            continue;
        }
        ++stk.back().second;  // bump before any push_back invalidates back()
        const int nxt=tr[cur][idx];
        if(nxt==par[cur]) continue;
        in[nxt]=sz(eul);
        eul.pb(nxt);
        dpth[nxt]=dpth[cur]+1;
        par[nxt]=cur;
        stk.emplace_back(nxt,0);
    }
}
// LCA of u and v in the tree rooted at vertex 0: the minimum-depth vertex on
// the Euler-tour range between their first occurrences, found with two
// overlapping sparse-table windows. Ties favour the left window.
int getlca(int u, int v)
{
    const int lo=std::min(in[u],in[v]);
    const int hi=std::max(in[u],in[v]);
    const int lvl=log2(hi-lo+1);
    const int left=lca[lo][lvl];
    const int right=lca[hi-(1<<lvl)+1][lvl];
    return dpth[right]<dpth[left] ? right : left;
}
// Lowest common ancestor of u and v in the centroid tree: repeatedly lift the
// deeper vertex (or both, at equal depth) via parcd until they meet.
int lcacd(int u, int v)
{
    while(u!=v)
    {
        if(dpthcd[u]>dpthcd[v])
            u=parcd[u];
        else if(dpthcd[v]>dpthcd[u])
            v=parcd[v];
        else
        {
            u=parcd[u];
            v=parcd[v];
        }
    }
    return u;
}
// Evaluates centroid `rt`, then recurses into its centroid-tree children
// (entries of cd[rt] other than the centroid parent `pr`).
//
// A BFS starting at rt keeps adding vertices that are forced in:
//   * every vertex of each included color (a color is included when its first
//     vertex is enqueued; that vertex becomes rep[color]), and
//   * the tree parent of every added vertex other than the current top `lc`.
// If a forced vertex v has lcacd(v,rt)!=rt (rt is not its centroid-tree
// ancestor), the search for this rt is abandoned (bug=1) and yields no
// candidate. Otherwise the number of included colors `used` updates `ans`.
void solve(int rt, int pr)
{
// visver[v]: 0 = not enqueued, 1 = enqueued, 2 = dequeued.
// Reading a missing key through operator[] inserts it with value 0.
unordered_map<int,int> visver;
// rep[c]: first enqueued vertex of color c; only it expands grp[c].
unordered_map<int,int> rep;
// lc: LCA (tree rooted at vertex 0 by makeul) of rt and the color-group
// vertices merged so far; used: number of included colors.
int lc=rt, used=0;
queue<int> q;
q.push(rt);
visver[rt]=1;
rep[col[rt]]=rt;
// Set when a forced vertex falls outside rt's centroid subtree.
bool bug=0;
while(!q.empty())
{
int u=q.front(); q.pop();
visver[u]=2;
// A vertex below the top lc needs its parent (the path up to lc).
if(u!=lc)
{
int mypar=par[u];
// Enqueue the parent only if it is unseen and its color is not included yet;
// otherwise it is reached through its color's rep expanding grp[].
if(visver[mypar]==0 && col[mypar]!=col[u] && !rep.count(col[mypar]))
{
if(lcacd(mypar,rt)!=rt)
{
bug=1;
break;
}
q.push(mypar);
visver[mypar]=1;
rep[col[mypar]]=mypar;
}
}
// Non-representatives are done; the rep pulls in the whole color group.
if(rep[col[u]]!=u) continue;
used++;
for(int v:grp[col[u]])
{
if(v==u) continue;
int newlc=getlca(lc,v);
// v is outside lc's subtree, so the top moves up to newlc.
if(newlc!=lc)
{
// The old top was dequeued while it was the top, so its parent step was
// skipped; do it now. (If it is still queued, the u!=lc branch above
// handles it on dequeue.)
if(visver[lc]==2)
{
if(lcacd(par[lc],rt)!=rt)
{
bug=1;
break;
}
if(!rep.count(col[par[lc]]))
{
q.push(par[lc]);
visver[par[lc]]=1;
rep[col[par[lc]]]=par[lc];
}
}
lc=newlc;
}
if(lcacd(v,rt)!=rt)
{
bug=1;
break;
}
visver[v]=1;
q.push(v);
}
// Propagate the inner-loop abort out of the BFS.
if(bug) break;
}
if(!bug) ans=min(ans,used);
// Recurse into centroid-tree children (skip the edge back to the parent).
for(int v:cd[rt])
{
if(v==pr) continue;
solve(v,rt);
}
}
// Reads the tree (n vertices, n-1 edges, 1-based ids) and each vertex's
// 1-based color. Builds the centroid decomposition, the Euler tour and its
// min-depth sparse table, runs solve from the centroid root, and prints ans-1.
int main()
{
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    cin>>n>>k;
    for(int e=1;e<n;e++)
    {
        int a,b;
        cin>>a>>b;
        --a;
        --b;
        tr[a].pb(b);
        tr[b].pb(a);
    }
    for(int i=0;i<n;i++)
    {
        cin>>col[i];
        grp[--col[i]].pb(i);  // switch to 0-based colors
    }

    decomp(0,-1);
    makeul(0,-1,0);

    // Sparse table: level j holds the shallowest vertex over 2^j positions.
    const int m=sz(eul);
    for(int i=0;i<m;i++) lca[i][0]=eul[i];
    for(int j=1;j<20;j++)
    {
        const int half=1<<(j-1);
        for(int i=0;i+2*half<=m;i++)
        {
            const int a=lca[i][j-1];
            const int b=lca[i+half][j-1];
            lca[i][j]=dpth[b]<dpth[a] ? b : a;
        }
    }

    solve(root,-1);
    cout<<ans-1<<'\n';
}
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |