Submission #1236730

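| # | Time | Username | Problem | Language | Result | Execution time | Memory |
|---|------|----------|---------|----------|--------|----------------|--------|
| 1236730 | | hamanp87 | Capital City (JOI20_capital_city) | C++17 | 100 / 100 | 340 ms | 41856 KiB |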
#include<bits/stdc++.h>
using namespace std;

// Optional optimizer pragmas (disabled).
//#pragma GCC optimize("O3,unroll-loops")
//#pragma GCC target("avx2")
//#pragma GCC target("sse4")

// Short-hand macros used throughout the solution.
#define all(v) v.begin(), v.end()
#define F first
#define S second
#define pub push_back
#define pob pop_back
#define puf push_front
#define pof pop_front
// Drop consecutive duplicates in place (expects a sorted container).
#define damoon(v) v.resize(unique(all(v)) - v.begin())
//#define randi uniform_int_distribution<long long>
//mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());
//randi dist(0,10000000000000000);

// Type aliases.
using pii   = pair<int, int>;
using pll   = pair<long long, long long>;
using pib   = pair<int, bool>;
using plb   = pair<long long, bool>;
using pip   = pair<int, pii>;
using ppi   = pair<pii, int>;
using veci  = vector<int>;
using vecl  = vector<long long>;
using vecb  = vector<bool>;
using vecp  = vector<pii>;
using seti  = set<int>;
using setl  = set<long long>;
using setp  = set<pii>;
using mapii = map<int, int>;
using mapll = map<long long, long long>;
using mapib = map<int, bool>;
using maplb = map<long long, bool>;

// Constants.
constexpr int inf    = 1'000'000'000;
constexpr int mod    = 1'000'000'007;
constexpr int neginf = -1'000'000'000;
constexpr int N      = 200'005;   // capacity of every town/color-indexed table below
const double PI = acos(-1);

// Global state (zero-initialized at program start).
int n, k;        // number of towns, number of colors
int ans, cnt;    // best color count found so far; stamp of the current Solve call
int col[N];      // 0-based color of each (1-based) town
int nxn[N];      // next town of the same color, forming a cycle per color
int see[N];      // set to 1 once a town has been used as a centroid
int mark[N];     // per-call stamp: cnt = in component, -cnt = already in the closure
int par[N];      // parent of a town on the path towards the current centroid
int siz[N];      // subtree size inside the component currently being sized
int prt[N];      // scratch "color already counted" flags
veci adj[N];     // adjacency lists of the tree
veci vec[N];     // towns of each color

void dfs1(int u,int p,veci& t)
{
    t.pub(u);
    siz[u]=1;
    for(int v:adj[u])
    {
        if(v==p or see[v])
            continue;

        dfs1(v,u,t);
        siz[u]+=siz[v];
    }
}

// dfs2: stamps every town of u's component with the current call id
// (mark[v]=cnt) and records par[v] = neighbour of v on the path towards u.
// par[u] itself is not written (the caller sets it). Removed towns
// (see[v]!=0) and the town p are not entered.
// Uses an explicit stack for the same stack-depth reason as dfs1.
void dfs2(int u,int p)
{
    std::vector<std::pair<int,int>> st{{u,p}};   // (town, parent) still to visit
    while(!st.empty())
    {
        const auto [x,px]=st.back();
        st.pop_back();
        mark[x]=cnt;
        for(int v:adj[x])
        {
            if(v==px || see[v])
                continue;
            par[v]=x;
            st.push_back({v,x});
        }
    }
}

void Solve(int x)
{
    cnt++;
    veci t;
    dfs1(x,0,t);
    int tot=t.size();
    int rot=x;
    for(int u:t)
        if(siz[u]>siz[rot])
            rot=u;
    rot=x;
    for(int u:t)
        if(siz[u]>siz[x]/2)
            rot=u;

    par[rot]=rot;
    dfs2(rot,rot);

    veci q;
    q.pub(rot);
    mark[rot]=-cnt;
    bool f=0;
    for(int i=0;i<q.size();i++)
    {
        int u=q[i];
        int v=u;
        while(mark[par[v]]==cnt)
        {
            v=par[v];
            if(mark[v]!=-cnt)
            {
                mark[v]=-cnt;
                q.pub(v);
            }
        }

        if(abs(mark[nxn[u]])!=cnt)
        {
            f=1;
            break;
        }
        else
        {
            if(mark[nxn[u]]==cnt)
            {
                mark[nxn[u]]=-cnt;
                q.pub(nxn[u]);
            }
        }
    }

    if(!f)
    {
        int ret=0;
        for(int u:q)
        {
            int c=col[u];
            if(prt[c]==0)
            {
                ret++;
                prt[c]=1;
            }
        }
        for(int u:q)
            prt[col[u]]=0;

        if(ret<ans)
        {
            ans=ret;
        }
    }

    see[rot]=1;
    for(int v:adj[rot])
        if(!see[v])
            Solve(v);
}

void solve()
{
    cin>>n>>k;
    for(int i=1;i<n;i++)
    {
        int u,v;
        cin>>u>>v;
        adj[u].pub(v);
        adj[v].pub(u);
    }
    for(int i=1;i<=n;i++)
    {
        cin>>col[i];
        col[i]--;
        if(col[i]<k)
            vec[col[i]].pub(i);
    }
    for(int i=0;i<k;i++)
    {
        if(vec[i].empty())
            continue;
        int m=vec[i].size();
        for(int j=0;j<m;j++)
        {
            int u=vec[i][j];
            int nxi=(j+1)%m;
            nxn[u]=vec[i][nxi];
        }
    }

    ans=n;
    cnt=0;
    for(int i=1;i<=n;i++)
        see[i]=mark[i]=par[i]=siz[i]=0;
    for(int i=0;i<=k;i++)
        prt[i]=0;

    Solve(1);

    cout<<ans-1<<"\n";
}

// main: switches iostreams to unsynchronized, untied mode for fast I/O, then
// runs solve() once per test case (a single case; the multi-test count read
// is left commented out).
int main()
{
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);

    //ifstream fin("in.txt");
    //ofstream fout("out.txt");

    int tests=1;
    //cin>>tests;
    for(int tc=0;tc<tests;tc++)
        solve();
}

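| # | Verdict | Execution time | Memory | Grader output |
|---|---------|----------------|--------|---------------|
Fetching results...

| # | Verdict | Execution time | Memory | Grader output |
|---|---------|----------------|--------|---------------|
Fetching results...

| # | Verdict | Execution time | Memory | Grader output |
|---|---------|----------------|--------|---------------|
Fetching results...

| # | Verdict | Execution time | Memory | Grader output |
|---|---------|----------------|--------|---------------|
Fetching results...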