Submission #787287

#TimeUsernameProblemLanguageResultExecution timeMemory
787287winter0101Capital City (JOI20_capital_city)C++14
100 / 100
784 ms68988 KiB
#include<bits/stdc++.h> using namespace std; #define all(fl) fl.begin(),fl.end() #define pb push_back #define fi first #define se second #define for1(i,j,k) for(int i=j;i<=k;i++) #define for2(i,j,k) for(int i=j;i>=k;i--) #define for3(i,j,k,l) for(int i=j;i<=k;i+=l) #define lb lower_bound #define ub upper_bound #define sz(a) (int)a.size() #define pii pair<int,int> #define pli pair<long long,int> #define gcd __gcd #define lcm(x,y) x*y/__gcd(x,y) #define pil pair<int,long long> const int maxn=2e5+9; vector<int>a[maxn]; int sub[maxn]; int st[maxn][18]; int dep[maxn]; int col[maxn]; int c[maxn]; vector<int>ver[maxn]; void dfs(int u,int par){ sub[u]=1; for (auto &v:a[u]){ if (v==par)continue; st[v][0]=u; for1(i,1,17)st[v][i]=st[st[v][i-1]][i-1]; dep[v]=dep[u]+1; dfs(v,u); sub[u]+=sub[v]; if (a[u][0]==par||sub[a[u][0]]<sub[v]){ swap(v,a[u][0]); } } } int tme=0; int h[maxn],pos[maxn],rev[maxn]; void hld(int u,int par,int head){ h[u]=head; pos[u]=++tme; rev[pos[u]]=u; if (a[u][0]!=par){ hld(a[u][0],u,head); } for (auto v:a[u]){ if (v==par||v==a[u][0])continue; hld(v,u,v); } } int lca(int u,int v){ if (u==0)return v; if (v==0)return u; if (dep[u]<dep[v])swap(u,v); int k=dep[u]-dep[v]; for1(i,0,17){ if (k>>i&1)u=st[u][i]; } if (u==v)return u; for2(i,17,0){ if (!st[u][i]||!st[v][i])continue; if (st[u][i]!=st[v][i]){ u=st[u][i]; v=st[v][i]; } } return st[u][0]; } int f[maxn]; int findset(int u){ if (f[u]<0)return u; return f[u]=findset(f[u]); } void unite(int u,int v){ u=findset(u),v=findset(v); if (u==v)return ; if (f[u]>f[v])swap(u,v); f[u]+=f[v]; col[u]=lca(col[u],col[v]); for (auto v1:ver[v]){ ver[u].pb(v1); } vector<int>().swap(ver[v]); f[v]=u; } int num[maxn],low[maxn],tim=0; int n,k; stack<int>scc; struct ITgetmin{ vector<int>st; void resz(int _n){ st.resize(4*_n+9); } void update(int id,int l,int r,int u,int val){ if (l>u||r<u)return ; if (l==r){ st[id]=val; return; } int mid=(l+r)/2; update(id*2,l,mid,u,val); update(id*2+1,mid+1,r,u,val); st[id]=min(st[id*2],st[id*2+1]); } int get(int id,int l,int r,int u,int v){ if (l>v||r<u||u>v)return n+1; if (u<=l&&r<=v)return st[id]; int mid=(l+r)/2; return min(get(id*2,l,mid,u,v),get(id*2+1,mid+1,r,u,v)); } }; pii combine(const pii &p, const pii &q){ if (p.se>q.se)return p; else return q; } struct IT{ vector<pii>st; void resz(int _n){ st.resize(4*_n+9); } void update(int id,int l,int r,int u,int val){ if (l>u||r<u)return ; if (l==r){ st[id]={l,val}; return; } int mid=(l+r)/2; update(id*2,l,mid,u,val); update(id*2+1,mid+1,r,u,val); st[id]=combine(st[id*2],st[id*2+1]); } pii get(int id,int l,int r,int u,int v){ if (l>v||r<u||u>v)return {0,0}; if (u<=l&&r<=v)return st[id]; int mid=(l+r)/2; return combine(get(id*2,l,mid,u,v),get(id*2+1,mid+1,r,u,v)); } }; ITgetmin st1; IT st2; void dfs(int u){ num[u]=low[u]=++tim; scc.push(u); //cerr<<"DFS "<<u<<'\n'; for (auto v:ver[u]){ st1.update(1,1,n,pos[v],num[u]); st2.update(1,1,n,pos[v],num[u]); } for (auto v:ver[u]){ int v1=v; while (h[v1]!=h[col[u]]){ int l=pos[h[v1]],r=pos[v1]; while (true){ pii tmp=st2.get(1,1,n,l,r); if (tmp.se!=n+1)break; dfs(c[rev[tmp.fi]]); } v1=st[h[v1]][0]; } int l=pos[col[u]],r=pos[v1]; if (l>r)swap(l,r); while (true){ pii tmp=st2.get(1,1,n,l,r); if (tmp.se!=n+1)break; dfs(c[rev[tmp.fi]]); } } for (auto v:ver[u]){ int v1=v; while (h[v1]!=h[col[u]]){ int l=pos[h[v1]],r=pos[v1]; low[u]=min(low[u],st1.get(1,1,n,l,r)); v1=st[h[v1]][0]; } int l=pos[col[u]],r=pos[v1]; if (l>r)swap(l,r); low[u]=min(low[u],st1.get(1,1,n,l,r)); } if (num[u]==low[u]){ while (scc.top()!=u){ int v=scc.top(); scc.pop(); unite(u,v); } scc.pop(); for (auto v:ver[findset(u)]){ st1.update(1,1,n,pos[v],n+1); } } else { for (auto v:ver[u]){ st1.update(1,1,n,pos[v],low[u]); } } } int bit[maxn]; void add(int pos1,int val){ for(;pos1<=n;pos1+=(pos1-(pos1&(pos1-1))))bit[pos1]+=val; } int get(int pos1){ int sum=0; for(;pos1>=1;pos1-=(pos1-(pos1&(pos1-1))))sum+=bit[pos1]; return sum; } int get(int l,int r){ if (l>r)return 0; return get(r)-get(l-1); } signed main(){ ios_base::sync_with_stdio(0); cin.tie(0); //freopen("temp.INP","r",stdin); //freopen("temp.OUT","w",stdout); cin>>n>>k; if (n==1){ cout<<0; return 0; } for1(i,1,n-1){ int u,v; cin>>u>>v; a[u].pb(v); a[v].pb(u); } dfs(1,0); hld(1,0,1); for1(i,1,n){ cin>>c[i]; ver[c[i]].pb(i); col[c[i]]=lca(col[c[i]],i); } st1.resz(n); for1(i,1,n){ st1.update(1,1,n,i,n+1); } st2.resz(n); for1(i,1,n){ st2.update(1,1,n,i,n+1); } for1(i,1,k)f[i]=-1; for1(i,1,k){ if (!num[i]){ dfs(i); } } int ans=k-1; for1(i,1,n)add(i,1); for1(i,1,k){ if (f[i]<0){ for (auto v:ver[i]){ add(pos[v],-1); } int sum=0; for (auto v:ver[i]){ int u=v; while(h[u]!=h[col[i]]){ int l=pos[h[u]],r=pos[u]; sum+=get(l,r); u=st[h[u]][0]; } int l=pos[col[i]],r=pos[u]; if (l>r)swap(l,r); sum+=get(l,r); } if (sum==0){ ans=min(ans,abs(f[i])-1); } for (auto v:ver[i]){ add(pos[v],1); } } } cout<<ans; }
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...