Submission #658164

# | Time | Username | Problem | Language | Result | Execution time | Memory
658164 | azberjibiou | Capital City (JOI20_capital_city) | C++17
Result: 100 / 100
Execution time: 768 ms | Memory: 106472 KiB
#include <bits/stdc++.h>
using namespace std;
// Array bound shared by every array indexed by a town id or by a colour id.
const int mxN=200005;
// Column count of the binary-lifting table sps; only levels 0..19 are
// filled (dfs1) and read (lca).
const int mxK=25;
// Competitive-programming shorthands. In the code shown here only pii and se
// are referenced; pll, fi, all, MOD and INF are unused.
#define pii pair<int, int>
#define pll pair<ll, ll>
#define fi first
#define se second
#define all(x) x.begin(), x.end()
typedef long long ll;
const ll MOD=1000000007;
const ll INF=1e18;

// N: number of towns (tree vertices 1..N); K: number of colours (1..K).
int N, K;
// v[x]: adjacency list of the tree (both directions are stored).
vector <int> v[mxN];
// ct[c]: towns whose input colour is c, in input order.
vector <int> ct[mxN];
// C[x]: colour/group id of town x; onion relabels towns to the surviving
// group id. C[0] is never read from input and stays 0.
int C[mxN];
// Filled by dfs1 from town 1: depth, parent (par[1] stays 0, the virtual
// parent of the root), and sps[x][i] = 2^i-th ancestor of x.
int dep[mxN], par[mxN], sps[mxN][mxK];
// root[c]: LCA of all towns currently in group c.
int root[mxN];
// Euler-tour entry/exit indices produced by dfs1; iidx is the running counter.
int in[mxN], out[mxN], iidx;
// Work queue of (number of towns in group, group id); main pops the smallest
// entry and onion re-inserts a group after merging.
set <pii> s;
// Strict ordering of town ids by their DFS entry index in[] (Euler-tour
// order). Used as the comparator of the per-group town sets et[].
struct cmp1 {
    bool operator()(int lhs, int rhs) const { return in[lhs] < in[rhs]; }
};
// et[c]: towns currently in group c, ordered by DFS entry index via cmp1.
set <int, cmp1> et[mxN];
// np[c]: towns outside group c that are the parent of some town in group c
// (built in main, maintained by onion). Town 0, the virtual parent of town 1,
// appears here for the group that contains town 1.
set <int> np[mxN];
// cnt[c]: number of original colours merged into group c.
int cnt[mxN];
// Smallest cnt over all groups found closed; main prints ans-1.
int ans=mxN;
// Root the tree at `now` (whose parent is `pre`) and fill, for every town
// reached:
//   in[x] / out[x] - pre-order entry index and the last index used inside
//                    x's subtree (y is in x's subtree iff
//                    in[x] <= in[y] <= out[x]);
//   dep[x], par[x] - depth and parent of every child visited below `now`;
//   sps[x][i]      - 2^i-th ancestor of x for i = 0..19 (binary lifting).
// dep[now] and sps[now][0] are read but not written here; main relies on the
// zero-initialised globals plus dep[0] = -1 for the root town 1.
//
// Children are visited in adjacency-list order exactly as a recursive DFS
// would. An explicit stack is used because a path-shaped tree of ~2e5 towns
// would otherwise nest ~2e5 calls, which can overflow a default-sized stack.
void dfs1(int now, int pre=-1)
{
    // One frame per "open call": the town, its parent, and the index of the
    // next adjacency entry still to examine.
    struct Frame { int node, parent; size_t next; };

    // Work done when a town is first entered (the recursive prologue).
    auto enter=[](int x)
    {
        in[x]=++iidx;
        for(int i=1;i<=19;i++)  sps[x][i]=sps[sps[x][i-1]][i-1];
    };

    enter(now);
    vector<Frame> stk;
    stk.push_back({now, pre, 0});
    while(!stk.empty())
    {
        Frame &f=stk.back();
        if(f.next==v[f.node].size())
        {
            // Every neighbour handled: close this subtree (the epilogue).
            out[f.node]=iidx;
            stk.pop_back();
            continue;
        }
        const int nxt=v[f.node][f.next++];
        if(nxt==f.parent)   continue;
        const int cur=f.node;
        dep[nxt]=dep[cur]+1;
        sps[nxt][0]=par[nxt]=cur;
        enter(nxt);
        stk.push_back({nxt, cur, 0});   // may invalidate `f`; it is not used after this
    }
}
// Lowest common ancestor of towns a and b, using the depths dep[] and the
// binary-lifting table sps[][] (levels 0..19) filled by dfs1.
int lca(int a, int b)
{
    const int kTop=19;
    // Make `a` the deeper endpoint.
    if(dep[b]>dep[a])   swap(a, b);
    // Lift a by the largest powers of two that keep it no shallower than b.
    for(int k=kTop;k>=0;--k)
        if(dep[a]-(1<<k)>=dep[b])   a=sps[a][k];
    if(a==b)    return a;
    // Lift both together while their ancestors differ; they stop as children
    // of the LCA.
    for(int k=kTop;k>=0;--k)
    {
        const int up_a=sps[a][k];
        const int up_b=sps[b][k];
        if(up_a!=up_b)
        {
            a=up_a;
            b=up_b;
        }
    }
    return sps[a][0];
}
// uf1: union-find over colour ids; onion sets uf1[c2]=c1 when group c2 is
//      merged into group c1.
// uf2: dependency links between groups; main sets uf2[c1]=c2 when group c1
//      needs a town of group c2, and findpar2 follows and compresses them.
int uf1[mxN], uf2[mxN];
// Reset uf1 so that every colour id 1..N is its own representative.
void init1()
{
    for(int x=N;x>=1;--x)   uf1[x]=x;
}
// Reset uf2 so that no id 1..N has a dependency link yet (each points to itself).
void init2()
{
    for(int x=N;x>=1;--x)   uf2[x]=x;
}
// Representative of id `a` in uf1. Every entry on the walked path is then
// pointed directly at the representative (path compression).
int findpar1(int a)
{
    int rep=a;
    while(uf1[rep]!=rep)    rep=uf1[rep];
    // Second pass: redirect each node on the path to the representative.
    while(uf1[a]!=rep)
    {
        const int nxt=uf1[a];
        uf1[a]=rep;
        a=nxt;
    }
    return rep;
}
// Follow dependency links from id `a`. Each hop is first resolved to its
// current group through findpar1, then the walk continues via uf2 until it
// reaches a group whose uf2 entry points to itself. Every group visited on
// the way has its uf2 entry redirected to that final group, which is returned.
int findpar2(int a)
{
    vector<int> visited;
    int cur=findpar1(a);
    while(uf2[cur]!=cur)
    {
        visited.push_back(cur);
        cur=findpar1(uf2[cur]);
    }
    for(int g : visited)    uf2[g]=cur;
    return cur;
}
// Merge group c2 into group c1; c1 survives as the representative.
// Afterwards group c1 describes the union:
//   - uf1[c2] = c1 links the ids in the first union-find;
//   - et[c1] is the union of both town sets, and every town that belonged to
//     c2 is relabelled with C[town] = c1;
//   - np[c1] is the union of both "outside parent" sets, minus any town now
//     inside the merged group;
//   - cnt[c1] accumulates the number of original colours merged;
//   - root[c1] becomes the LCA of both group roots.
// Finally c1 is re-queued in `s`, keyed by its new town count.
void onion(int c1, int c2)
{
    uf1[c2]=c1;
    // Small-to-large: keep the bigger town set in et[c1] so the set unions
    // below walk the smaller side. After a swap, et[c1]/np[c1] hold the data
    // that used to belong to c2.
    const bool swapped=et[c1].size()<et[c2].size();
    if(swapped)
    {
        swap(np[c1], np[c2]);
        swap(et[c1], et[c2]);
    }
    // Towns that join the merged group are no longer outside parents.
    for(int e : et[c2]) np[c1].erase(e);
    // Outside parents of the smaller side stay outside unless the bigger side
    // already contains them.
    for(int e : np[c2]) if(!et[c1].count(e))    np[c1].insert(e);
    // Relabel the towns that belonged to c2. Without a swap they are in
    // et[c2]. After a swap they are in et[c1], and this runs before the union
    // so only those towns are touched. Relabelling et[c2] alone would leave
    // them with C == c2, a dead id that callers would pass back into onion.
    // NOTE(review): with main's smallest-first order the swap appears not to
    // fire, but C stays consistent regardless.
    for(int e : swapped ? et[c1] : et[c2])  C[e]=c1;
    for(int e : et[c2]) et[c1].insert(e);
    et[c2].clear();
    np[c2].clear();
    cnt[c1]+=cnt[c2];
    root[c1]=lca(root[c1], root[c2]);
    s.insert(pii(et[c1].size(), c1));
}
// Reads the tree and town colours, then repeatedly processes the smallest
// remaining colour group:
//   - a group whose only outside parent is the parent of its root is closed,
//     and its colour count is an answer candidate;
//   - otherwise it depends on the group owning one of its outside parents.
//     If that dependency chain leads back to it, the two groups are merged;
//     if not, the dependency is recorded.
// Prints the minimum number of merges, i.e. colours in the best closed group
// minus one.
int main()
{
    cin.tie(0);
    ios::sync_with_stdio(false);
    // Node 0 acts as the parent of the root town 1 (par[1] and sps[1][0]
    // stay 0); depth -1 keeps it above every real town.
    dep[0]=-1;
    cin >> N >> K;
    for(int i=1;i<N;i++)
    {
        int a, b;
        cin >> a >> b;
        v[a].push_back(b);
        v[b].push_back(a);
    }
    for(int i=1;i<=N;i++)   cin >> C[i], ct[C[i]].push_back(i);
    // A colour with exactly one town is already connected: zero merges.
    for(int i=1;i<=K;i++)   if(ct[i].size()==1)
    {
        cout << 0;
        return  0;
    }
    dfs1(1);
    // root[i] = LCA of all towns of colour i.
    // NOTE(review): assumes every colour 1..K occurs at least once; an empty
    // ct[i] would make ct[i][0] out of range. Confirm against the constraints.
    for(int i=1;i<=K;i++)
    {
        root[i]=ct[i][0];
        for(int ele : ct[i])    root[i]=lca(root[i], ele);
    }
    for(int i=1;i<=K;i++)   s.insert(pii(ct[i].size(), i));
    for(int i=1;i<=K;i++)   for(int ele : ct[i])    et[i].insert(ele);
    // Outside parents: parents of colour-i towns that have a different colour
    // (town 0, with C[0]==0, is included for the colour of town 1).
    for(int i=1;i<=K;i++)   for(int ele : ct[i])    if(C[par[ele]]!=i)   np[i].insert(par[ele]);
    for(int i=1;i<=K;i++)   cnt[i]=1;
    init1();
    init2();
    while(s.size())
    {
        // Take the group with the fewest towns (ties by smaller id).
        int c1=s.begin()->se;
        s.erase(s.begin());
        // Closed group: the only outside parent is the parent of its root.
        if(np[c1].size()==1 && *np[c1].begin()==par[root[c1]])
        {
            ans=min(ans, cnt[c1]);
            continue;
        }
        // Pick an outside parent other than par[root[c1]]; its group is needed.
        auto it=np[c1].begin();
        int x=*it;
        if(x==par[root[c1]])    x=*(++it);
        int c2=C[x];
        // The chain of dependencies from c2 returns to c1: merge them.
        // Otherwise record c1 -> c2; c1 is not re-queued.
        if(findpar2(c2)==c1) onion(c1, c2);
        else    uf2[c1]=c2;
    }
    cout << ans-1;
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...