Submission #223907

#TimeUsernameProblemLanguageResultExecution timeMemory
223907medkCapital City (JOI20_capital_city)C++14
11 / 100
462 ms147968 KiB
#include <bits/stdc++.h> #define ll long long #define pb push_back #define x first #define y second #define all(u) u.begin(),u.end() #define sz(u) (int)(u.size()) #define INF (int)(1e9) using namespace std; const int MAXN=2e5+5; int n,k,ans=INF; vector<vector<int>> tr(MAXN),grp(MAXN),cd(MAXN); int col[MAXN],subsz[MAXN],parcd[MAXN],par[MAXN],dpth[MAXN],in[MAXN],dpthcd[MAXN]; bool blocked[MAXN]; vector<vector<int>> lca(MAXN,vector<int>(20)); int root=-1; vector<int> eul; void calc_sz(int u, int pr) { subsz[u]=1; for(int v:tr[u]) { if(v==pr || blocked[v]) continue; calc_sz(v,u); subsz[u]+=subsz[v]; } } int find_centroid(int u, int pr, int tar) { for(int v:tr[u]) { if(v==pr || blocked[v]) continue; if(subsz[v]>tar) return find_centroid(v,u,tar); } return u; } void decomp(int u, int pr, int d=0) { calc_sz(u,pr); int centroid=find_centroid(u,pr,subsz[u]/2); blocked[centroid]=1; parcd[centroid]=pr; dpthcd[centroid]=d; if(pr==-1) root=centroid; else { cd[centroid].pb(pr); cd[pr].pb(centroid); } for(int v:tr[centroid]) { if(blocked[v]) continue; decomp(v,centroid,d+1); } } void makeul(int u, int pr, int d) { in[u]=sz(eul); eul.pb(u); dpth[u]=d; par[u]=pr; d++; for(int v:tr[u]) { if(pr==v) continue; makeul(v,u,d); eul.pb(u); } } int getlca(int u, int v) { int l=in[u], r=in[v]; if(l>r) swap(l,r); int lg=log2(r-l+1); if(dpth[lca[l][lg]]<=dpth[lca[r-(1<<lg)+1][lg]]) return lca[l][lg]; return lca[r-(1<<lg)+1][lg]; } int lcacd(int u, int v) { if(dpthcd[u]==dpthcd[v]) { if(u==v) return u; else return lcacd(parcd[u],parcd[v]); } if(dpthcd[v]<dpthcd[u]) swap(u,v); return lcacd(u,parcd[v]); } void solve(int rt, int pr) { unordered_map<int,int> visver; unordered_map<int,int> rep; int lc=rt, used=0; queue<int> q; q.push(rt); visver[rt]=1; rep[col[rt]]=rt; bool bug=0; while(!q.empty()) { int u=q.front(); q.pop(); visver[u]=2; if(u!=lc) { int mypar=par[u]; if(visver[mypar]==0 && col[mypar]!=col[u] && !rep.count(col[mypar])) { if(lcacd(mypar,rt)!=rt) { bug=1; break; } q.push(mypar); visver[mypar]=1; rep[col[mypar]]=mypar; } } if(rep[col[u]]!=u) continue; used++; for(int v:grp[col[u]]) { if(v==u) continue; int newlc=getlca(lc,v); if(newlc!=lc) { if(visver[lc]==2) { if(lcacd(par[lc],rt)!=rt) { bug=1; break; } if(!rep.count(col[par[lc]])) { q.push(par[lc]); visver[par[lc]]=1; rep[col[par[lc]]]=par[lc]; } } lc=newlc; } if(lcacd(v,rt)!=rt) { bug=1; break; } visver[v]=1; q.push(v); } if(bug) break; } if(!bug) ans=min(ans,used); for(int v:cd[rt]) { if(v==pr) continue; solve(v,rt); } } int main() { //freopen("test.in","r",stdin); ios::sync_with_stdio(0); cin.tie(0); cout.tie(0); cin>>n>>k; for(int i=0;i<n-1;i++) { int u,v; cin>>u>>v; u--,v--; tr[u].pb(v); tr[v].pb(u); } for(int i=0;i<n;i++) { cin>>col[i]; col[i]--; grp[col[i]].pb(i); } decomp(0,-1); makeul(0,-1,0); for(int i=0;i<sz(eul);i++) lca[i][0]=eul[i]; for(int j=1;j<20;j++) for(int i=0;i+(1<<j)<=sz(eul);i++) lca[i][j]=(dpth[lca[i][j-1]]<=dpth[lca[i+(1<<(j-1))][j-1]]?lca[i][j-1]:lca[i+(1<<(j-1))][j-1]); solve(root,-1); cout<<ans-1<<'\n'; }
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...