Submission #787278

#TimeUsernameProblemLanguageResultExecution timeMemory
787278winter0101수도 (JOI20_capital_city)C++14
0 / 100
645 ms68984 KiB
#include<bits/stdc++.h>
using namespace std;
#define all(fl) fl.begin(),fl.end()
#define pb push_back
#define fi first
#define se second
#define for1(i,j,k) for(int i=j;i<=k;i++)
#define for2(i,j,k) for(int i=j;i>=k;i--)
#define for3(i,j,k,l) for(int i=j;i<=k;i+=l)
#define lb lower_bound
#define ub upper_bound
#define sz(a) (int)a.size()
#define pii pair<int,int>
#define pli pair<long long,int>
#define gcd __gcd
#define lcm(x,y) x*y/__gcd(x,y)
#define pil pair<int,long long>
const int maxn=2e5+9;
vector<int>a[maxn];
int sub[maxn];
int st[maxn][18];
int dep[maxn];
int col[maxn];
int c[maxn];
vector<int>ver[maxn];
void dfs(int u,int par){
sub[u]=1;
for (auto &v:a[u]){
    if (v==par)continue;
    st[v][0]=u;
    for1(i,1,17)st[v][i]=st[st[v][i-1]][i-1];
    dep[v]=dep[u]+1;
    dfs(v,u);
    sub[u]+=sub[v];
    if (a[u][0]==par||sub[a[u][0]]<sub[v]){
        swap(v,a[u][0]);
    }
}
}
int tme=0;
int h[maxn],pos[maxn],rev[maxn];
void hld(int u,int par,int head){
h[u]=head;
pos[u]=++tme;
rev[pos[u]]=u;
if (a[u][0]!=par){
    hld(a[u][0],u,head);
}
for (auto v:a[u]){
    if (v==par||v==a[u][0])continue;
    hld(v,u,v);
}
}
int lca(int u,int v){
if (u==0)return v;
if (v==0)return u;
if (dep[u]<dep[v])swap(u,v);
int k=dep[u]-dep[v];
for1(i,0,17){
if (k>>i&1)u=st[u][i];
}
if (u==v)return u;
for2(i,17,0){
if (!st[u][i]||!st[v][i])continue;
if (st[u][i]!=st[v][i]){
    u=st[u][i];
    v=st[v][i];
}
}
return st[u][0];
}
int f[maxn];
int findset(int u){
if (f[u]<0)return u;
return f[u]=findset(f[u]);
}
void unite(int u,int v){
u=findset(u),v=findset(v);
if (u==v)return ;
if (f[u]>f[v])swap(u,v);
f[u]+=f[v];
col[u]=lca(col[u],col[v]);
for (auto v1:ver[v]){
    ver[u].pb(v1);
}
vector<int>().swap(ver[v]);
f[v]=u;
}
int num[maxn],low[maxn],tim=0;
int n,k;
stack<int>scc;
struct ITgetmin{
vector<int>st;
void resz(int _n){
st.resize(4*_n+9);
}
void update(int id,int l,int r,int u,int val){
if (l>u||r<u)return ;
if (l==r){
    st[id]=val;
    return;
}
int mid=(l+r)/2;
update(id*2,l,mid,u,val);
update(id*2+1,mid+1,r,u,val);
st[id]=min(st[id*2],st[id*2+1]);
}
int get(int id,int l,int r,int u,int v){
if (l>v||r<u||u>v)return n+1;
if (u<=l&&r<=v)return st[id];
int mid=(l+r)/2;
return min(get(id*2,l,mid,u,v),get(id*2+1,mid+1,r,u,v));
}
};
pii combine(const pii &p, const pii &q){
if (p.se>q.se)return p;
else return q;
}
struct IT{
vector<pii>st;
void resz(int _n){
st.resize(4*_n+9);
}
void update(int id,int l,int r,int u,int val){
if (l>u||r<u)return ;
if (l==r){
    st[id]={l,val};
    return;
}
int mid=(l+r)/2;
update(id*2,l,mid,u,val);
update(id*2+1,mid+1,r,u,val);
st[id]=combine(st[id*2],st[id*2+1]);
}
pii get(int id,int l,int r,int u,int v){
if (l>v||r<u||u>v)return {0,0};
if (u<=l&&r<=v)return st[id];
int mid=(l+r)/2;
return combine(get(id*2,l,mid,u,v),get(id*2+1,mid+1,r,u,v));
}
};
ITgetmin st1;
IT st2;
void dfs(int u){
num[u]=low[u]=++tim;
scc.push(u);
//cerr<<"DFS "<<u<<'\n';
for (auto v:ver[u]){
    st1.update(1,1,n,pos[v],num[u]);
    st2.update(1,1,n,pos[v],num[u]);
}
for (auto v:ver[u]){
    int v1=v;
    while (h[v1]!=h[col[u]]){
        int l=pos[h[v1]],r=pos[v1];
        while (true){
            pii tmp=st2.get(1,1,n,l,r);
            if (tmp.se!=n+1)break;
            dfs(c[rev[tmp.fi]]);
        }
        v1=st[h[v1]][0];
    }
    int l=pos[col[u]],r=pos[v1];
    if (l>r)swap(l,r);
     while (true){
        pii tmp=st2.get(1,1,n,l,r);
        if (tmp.se!=n+1)break;
        dfs(c[rev[tmp.fi]]);
    }
}
for (auto v:ver[u]){
    int v1=v;
    while (h[v1]!=h[col[u]]){
        int l=pos[h[v1]],r=pos[v1];
        low[u]=min(low[u],st1.get(1,1,n,l,r));
        v1=st[h[v1]][0];
    }
    int l=pos[col[u]],r=pos[v1];
    if (l>r)swap(l,r);
    low[u]=min(low[u],st1.get(1,1,n,l,r));
}
if (num[u]==low[u]){
    while (scc.top()!=u){
        int v=scc.top();
        scc.pop();
        unite(u,v);
    }
    scc.pop();
    for (auto v:ver[findset(u)]){
        st1.update(1,1,n,pos[v],n+1);
    }
}
}
int bit[maxn];
void add(int pos1,int val){
for(;pos1<=n;pos1+=(pos1-(pos1&(pos1-1))))bit[pos1]+=val;
}
int get(int pos1){
int sum=0;
for(;pos1>=1;pos1-=(pos1-(pos1&(pos1-1))))sum+=bit[pos1];
return sum;
}
int get(int l,int r){
if (l>r)return 0;
return get(r)-get(l-1);
}
signed main(){
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    //freopen("temp.INP","r",stdin);
    //freopen("temp.OUT","w",stdout);
    cin>>n>>k;
    for1(i,1,n-1){
    int u,v;
    cin>>u>>v;
    a[u].pb(v);
    a[v].pb(u);
    }
    dfs(1,0);
    hld(1,0,1);
    for1(i,1,n){
    cin>>c[i];
    ver[c[i]].pb(i);
    col[c[i]]=lca(col[c[i]],i);
    }
    st1.resz(n);
    for1(i,1,n){
    st1.update(1,1,n,i,n+1);
    }
    st2.resz(n);
    for1(i,1,n){
    st2.update(1,1,n,i,n+1);
    }
    for1(i,1,k)f[i]=-1;
    for1(i,1,k){
    if (!num[i]){
        dfs(i);
    }
    }
    int ans=k-1;
    for1(i,1,n)add(i,1);
    for1(i,1,k){
    if (f[i]<0){
     for (auto v:ver[i]){
        add(pos[v],-1);
     }
     int sum=0;
     for (auto v:ver[i]){
        int u=v;
        while(h[u]!=h[col[i]]){
            int l=pos[h[u]],r=pos[u];
            sum+=get(l,r);
            u=st[h[u]][0];
        }
        int l=pos[col[i]],r=pos[u];
        if (l>r)swap(l,r);
        sum+=get(l,r);
     }
     if (sum==0){
        ans=min(ans,abs(f[i])-1);
     }
     for (auto v:ver[i]){
        add(pos[v],1);
     }
    }
    }
    cout<<ans;

}
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...