제출 #1239898

# | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리
1239898 | damoon | 수도 (JOI20_capital_city) | C++20
100 / 100
419 ms | 38004 KiB
#include <bits/stdc++.h>
using namespace std;

//#pragma GCC optimize("O3,unroll-loops") //main
//#pragma GCC target("avx2") //cf...
//#pragma GCC target("sse4") //Quera

// Shorthand types and macros from the author's contest template.
// NOTE(review): `ll`, `f`, `s`, `lc`, `rc`, `pb` and `pp` are object-like
// macros, so the preprocessor rewrites every later token with one of those
// names. Do not use them as variable/function names anywhere below.
#define ll long long
typedef pair<int,int> pii;
typedef pair<int,pii> pip;
typedef pair<pii,int> ppi;
typedef pair<pii,pii> ppp;
// `x.f` / `x.s` expand to `x.first` / `x.second`.
#define f first
#define s second

// Child indices (2*id, 2*id+1) of node `id` in an implicit binary tree. They
// need a variable named `id` at the use site and are not used in the visible code.
// NOTE(review): the expansions are unparenthesised, so e.g. `rc*2` becomes
// `2*id+1*2`. Parenthesise at the use site.
#define lc 2*id
#define rc 2*id+1
#define all(x) x.begin(),x.end()

#define pb push_back
#define pp pop_back
// Removes consecutive duplicates from container x in place (std::unique, then
// resize). x must be sorted first for this to remove all duplicates.
#define unicorn(x) x.resize(unique(x.begin(),x.end())-x.begin())

// Debug printers from the template. Each writes its elements to cout, with
// every element followed by a space (pairs are printed as "( first,second )").
// Each returns an empty string, so it can be used inside a stream expression
// such as `cout << pr(v) << endl;`.
// The pointer overloads print vv[l], ..., vv[r-1]. Containers are taken by
// const reference and pointers as pointer-to-const, so no copy is made per call.
string pr(const int* vv,int l,int r){
    for(int i=l;i<r;i++) cout<<vv[i]<<" ";
    return "";
}
string pr(const long long* vv,int l,int r){
    for(int i=l;i<r;i++) cout<<vv[i]<<" ";
    return "";
}
string pr(const vector<int>& vv){
    for(int x:vv) cout<<x<<" ";
    return "";
}
string pr(const vector<long long>& vv){
    for(long long x:vv) cout<<x<<" ";
    return "";
}
string pr(const pair<int,int>* vv,int l,int r){
    for(int i=l;i<r;i++) cout<<"( "<<vv[i].first<<","<<vv[i].second<<" )    ";
    return "";
}
string pr(const vector<pair<int,int>>& vv){
    for(const auto& x:vv) cout<<"( "<<x.first<<","<<x.second<<" )    ";
    return "";
}

// Randomness helpers from the template; the visible solution does not use them.
// rng is seeded once from std::random_device (non-deterministic on most platforms).
random_device device;
default_random_engine rng(device());
// randt(a,b): a uniformly random integer in the closed range [a, b], drawn from rng.
#define randt(a,b) uniform_int_distribution<int64_t>(a,b)(rng)

// Capacity of every per-vertex / per-label array. Vertex ids 1..n and labels
// are assumed to stay below L (TODO confirm against the problem limits).
// mod is not used in the visible code.
const int L = 2e5+10,mod = 1e9+7;
// Sentinel meaning "no valid set at this centroid"; also the initial value of ans.
const int inf = 1e9+10;
// n = number of vertices (read in main); k is read in main but not otherwise
// used (presumably the number of labels); C = centroid reported by Find().
int n,k,C;
// Smallest label-set size found over all centroids; main prints ans-1.
int ans;
// a[v] = label of vertex v.
int a[L];
// occ[c] = vertices whose label is c; adj[v] = neighbours of v (n-1 edges read in main).
vector<int> occ[L],adj[L];
// cnt[c] = number of vertices with label c inside the current component (filled by dfs).
// subt[v] = subtree size of v (set by Find, then recomputed by dfs from the centroid).
int cnt[L],subt[L];
// mark[v]: v was already used as a centroid and is treated as removed.
// isc[c]: label c has already been added to the set being built in solve().
// isv[v]: vertex v has already been climbed over while building that set.
bool mark[L],isc[L],isv[L];
// Worklist of labels still to process in solve().
vector<int> V;
// par[v] = parent of v when the current component is rooted at its centroid (0 for the root).
int par[L];

// Computes subtree sizes of the unmarked component containing v (rooted at v,
// entered from parent p) and writes a centroid into the global C. A centroid
// is a vertex whose removal leaves pieces of at most sz/2 vertices each: every
// child subtree, and the parent-side remainder sz - subt[v].
// When two vertices qualify, the one finished last in post-order wins.
void Find(int v,int p,int sz){
    int heaviest = 0;   // largest child subtree seen so far
    subt[v] = 1;
    for(int nxt : adj[v]){
        if(nxt == p || mark[nxt])
            continue;
        Find(nxt, v, sz);
        subt[v] += subt[nxt];
        heaviest = max(heaviest, subt[nxt]);
    }
    const bool childrenSmall = 2*heaviest <= sz;
    const bool parentSideSmall = 2*subt[v] >= sz;   // i.e. sz - subt[v] <= sz/2
    if(childrenSmall && parentSideSmall)
        C = v;
}

// Clears the scratch state for the unmarked component containing v (entered
// from parent p) before that component is processed: the per-label counters
// cnt[], the per-label "taken" flags isc[], and the per-vertex "walked" flags isv[].
void reset(int v,int p){
    const int col = a[v];
    cnt[col] = 0;
    isc[col] = false;
    isv[v] = false;
    for(int nxt : adj[v]){
        if(nxt == p || mark[nxt])
            continue;
        reset(nxt, v);
    }
}

// Roots the unmarked component at v. For every vertex it sets par[] (for the
// descendants), computes subtree sizes in subt[], and tallies its label in cnt[].
// The caller must set par[v] first (solve() sets par[centroid] = 0).
void dfs(int v){
    subt[v] = 1;
    ++cnt[a[v]];
    for(int child : adj[v]){
        if(child == par[v] || mark[child])
            continue;
        par[child] = v;
        dfs(child);
        subt[v] += subt[child];
    }
}

// One step of the centroid decomposition over the unmarked component that
// contains w (sz = number of vertices in that component).
// It roots the component at its centroid, then grows the label set that must
// accompany the centroid's label:
//   - every label taken must have all of its vertices inside this component;
//   - every vertex on the path from such a vertex to the centroid adds its own label.
// The resulting label count (or inf if the set would leave the component)
// updates the global ans. The centroid is then removed and each remaining
// piece is solved recursively.
void solve(int w,int sz){
    reset(w,0);
    Find(w,0,sz);
    // C is a global that the recursive solve() calls below overwrite, so keep
    // this level's centroid in a local instead of re-reading C later.
    const int cen = C;
    par[cen] = 0;
    dfs(cen);
    mark[cen] = 1;

    V.clear();
    V.push_back(a[cen]);
    int res = 0;
    while(!V.empty()){
        const int col = V.back();
        V.pop_back();
        if(isc[col])
            continue;
        // cnt[] only counts vertices of this component, so a mismatch means
        // some vertex with label col lies outside it.
        if(cnt[col] != static_cast<int>(occ[col].size())){
            res = inf;
            break;
        }
        isc[col] = 1;
        res++;
        // Climb from every vertex with this label towards the centroid. Each
        // vertex on the way must be in the set, so queue its label. Stop at
        // the root's parent sentinel 0, or at a vertex an earlier climb already
        // covered (its path to the centroid is already handled).
        for(int v : occ[col]){
            for(int cur = v; cur != 0 && !isv[cur]; cur = par[cur]){
                isv[cur] = 1;
                V.push_back(a[cur]);
            }
        }
    }
    ans = min(ans,res);

    // With cen removed, subt[u] (computed by dfs rooted at cen) is exactly the
    // size of the piece hanging off neighbour u.
    for(int u : adj[cen]){
        if(!mark[u])
            solve(u,subt[u]);
    }
}

int main(){
    // Optional file redirection used while debugging:
    //ofstream cout ("out.out");
    //ifstream cin ("in.in");

    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);

    // n vertices, then n-1 undirected edges.
    cin >> n >> k;
    for(int edge = 1; edge < n; ++edge){
        int x, y;
        cin >> x >> y;
        adj[x].push_back(y);
        adj[y].push_back(x);
    }
    // The label of each vertex, plus the inverse index: label -> vertices.
    for(int v = 1; v <= n; ++v){
        cin >> a[v];
        occ[a[v]].push_back(v);
    }

    ans = inf;
    solve(1, n);

    // ans is the smallest label-set size found. The answer printed is one less
    // (presumably the number of merges needed to unite that many labels).
    cout << ans - 1 << endl;
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...