Submission #1004849

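| # | Time | Username | Problem | Language | Result | Execution time | Memory |
|---|------|----------|---------|----------|--------|----------------|--------|
| 1004849 | | Unforgettablepl | Capital City (JOI20_capital_city) | C++17 | 100 / 100 | 350 ms | 45204 KiB |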
// #pragma GCC optimize("O3")
#include <bits/stdc++.h>
using namespace std;

// Redefines `int` as long long for the rest of the file (presumably so that
// sentinels such as 1e15 fit); this is why main() is declared as int32_t.
#define int long long

// Global state, indexed directly by node id (1..n) or colour id. The fixed
// bound of 200001 presumably matches a limit of 200000 nodes/colours in the
// problem statement -- confirm against the constraints.
// Tree adjacency lists, filled in main().
vector<int> adj[200001];
// True once a node has been used as a centroid; every traversal stops there.
bool processed[200001];
// Subtree sizes computed by centroidcalc().
int siz[200001];
// True for nodes of the component currently being examined by calc().
bool current[200001];
// Parent of each node of that component when rooted at its centroid (0 at the centroid).
int par[200001];
// Per colour: id of the centroid whose calc() call last merged it (0 = never).
int visitedcol[200001];
// Colour of each node.
int col[200001];
// Per colour: the list of nodes having that colour.
vector<int> nodes[200001];

// Walks the unprocessed component containing x. For every node v reached it
// sets current[v] = set_val and records par[v], the parent of v when the
// component is rooted at x (par[x] = p). Nodes with processed[v] set are never
// entered, and the walk never steps from x back to p.
//
// Uses an explicit stack instead of recursion: on a path-shaped tree the
// component can contain every node, and a recursive DFS would then nest one
// call per node, which can overflow a default-sized call stack.
void set_all(int x,int p,bool set_val){
    std::vector<std::pair<int,int>> todo; // (node, parent) pairs still to visit
    todo.emplace_back(x,p);
    while(!todo.empty()){
        const auto [v,from] = todo.back();
        todo.pop_back();
        if(processed[v])continue;
        current[v] = set_val;
        par[v] = from;
        for(int u:adj[v]){
            if(u!=from)todo.emplace_back(u,v);
        }
    }
}

// Returns the number of unprocessed nodes reachable from x without entering a
// processed node or stepping from x back to p (0 if x itself is processed).
//
// Uses an explicit stack instead of recursion: on a path-shaped tree a
// recursive count would nest one call per node, which can overflow a
// default-sized call stack.
int cnt(int x,int p){
    int total = 0;
    std::vector<std::pair<int,int>> todo; // (node, parent) pairs still to count
    todo.emplace_back(x,p);
    while(!todo.empty()){
        const auto [v,from] = todo.back();
        todo.pop_back();
        if(processed[v])continue;
        ++total;
        for(int u:adj[v]){
            if(u!=from)todo.emplace_back(u,v);
        }
    }
    return total;
}

// Output slot of centroidcalc(): the centroid it selected for the component it
// scanned. calc() reads it right after calling centroidcalc().
int centroid;

// Finds a centroid of the unprocessed component containing x and stores it in
// the global `centroid`. `tot` is the component's size (as returned by cnt()).
//
// Rooting the component at x (never entering processed nodes, never stepping
// from x back to p), it fills siz[v] for every node v. Each child contributes
// its own result to its parent's size. A node qualifies when every child
// contribution and the part above it (tot - siz[v]) are at most tot/2. A
// qualifying node contributes the sentinel 1e15 instead of its size, which
// makes every ancestor fail the test, so only one node is recorded. The
// sentinel relies on `int` being the file's 64-bit long long.
//
// Returns 0 if x is processed. Otherwise it returns 1e15 if x itself is the
// centroid, and siz[x] if not (siz[x] includes the sentinel when the centroid
// lies below x).
//
// Iterative (BFS collection + reverse scan) rather than recursive: on a
// path-shaped tree a recursive post-order would nest one call per node, which
// can overflow a default-sized call stack.
int centroidcalc(int x,int p,int tot){
    if(processed[x])return 0;
    // The component in BFS order, as (node, index of its parent in `order`).
    // Every node is appended after its parent, so scanning `order` backwards
    // handles all children before their parent, as a post-order DFS would.
    std::vector<std::pair<int,int>> order;
    order.emplace_back(x,-1);
    for(std::size_t k=0;k<order.size();k++){
        const int v = order[k].first;
        const int from = (k==0) ? p : order[order[k].second].first;
        for(int u:adj[v]){
            if(u!=from && !processed[u])order.emplace_back(u,static_cast<int>(k));
        }
    }
    for(const auto& entry:order)siz[entry.first]=1;
    // fits[i]: no piece around order[i] seen so far is larger than tot/2.
    std::vector<char> fits(order.size(),1);
    int result = 0;
    for(std::size_t k=order.size();k>0;k--){
        const std::size_t i = k-1;
        const int v = order[i].first;
        if(tot-siz[v]>tot/2)fits[i]=0; // the part of the component above v is too big
        if(fits[i]){
            centroid = v;
            result = 1e15; // sentinel > tot/2: every ancestor now fails the test
        }else{
            result = siz[v];
        }
        const int up = order[i].second;
        if(up>=0){
            siz[order[up].first]+=result;
            if(result>tot/2)fits[up]=0;
        }
    }
    return result; // the last value computed is the one for x
}

int calc(int x){
    centroidcalc(x,0,cnt(x,0));
    set_all(centroid,0,true); // Get them into processing stage
    queue<int> q; // Queue of colours
    q.emplace(col[centroid]);
    int ans = -1;
    while(!q.empty()){
        int curr = q.front();q.pop();
        if(visitedcol[curr]==centroid)continue;
        visitedcol[curr]=centroid;
        ans++;
        for(int&i:nodes[curr]){
            if(!current[i]){
                while(!q.empty())q.pop();
                ans = 1e15;
                break;
            }
            if(par[i])q.emplace(col[par[i]]);
        }
    }
    set_all(centroid,0,false);
    processed[centroid]=true;
    for(int&i:adj[centroid])if(!processed[i])ans=min(ans,calc(i));
    return ans;
}

// Reads n, k, the n-1 tree edges and the colour of each of the n nodes, then
// prints the result of calc() on the component containing node 1 (the whole
// tree). Declared int32_t because `int` is macro-expanded to long long in
// this file.
int32_t main(){
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);
    int n,k;
    std::cin >> n >> k; // k is read only to consume the input; it is not used
    for(int i=1;i<n;i++){
        int a,b;
        std::cin>>a>>b;
        adj[a].emplace_back(b);
        adj[b].emplace_back(a);
    }
    for(int i=1;i<=n;i++){
        std::cin>>col[i];
        nodes[col[i]].emplace_back(i);
    }
    // No need to seed siz[]: centroidcalc() writes siz[v] before reading it.
    std::cout << calc(1) << '\n';
}
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...