// Submission #1370325
//
// # | Time | Username | Problem | Language | Result | Execution time | Memory
// 1370325 | - | sarahspeedy | Capital City (JOI20_capital_city) | C++20 | 100 / 100 | 369 ms | 45340 KiB
#include<bits/stdc++.h>
#define int long long
#define pb push_back
#define eb emplace_back
#define mp make_pair
#define F first
#define S second
#define vi vector<int>
#define pii pair<int,int>

using namespace std;

// Capacity of the per-node arrays below.
// NOTE(review): presumably the problem's maximum n plus slack — confirm against the statement.
// Because of the `#define int long long` above, every `int` here is 64-bit.
const int mxn = 2e5 + 5;

// n: number of nodes, k: number of colours; both are read in main.
int n, k;

// c[v]: 0-based colour of node v (main reads it 1-based and decrements).
int c[mxn];

// sz[v]: subtree size of v as computed by the most recent get_sz call.
int sz[mxn];

// vis[v]: 1 once v has been used as a centroid and removed from further decomposition.
int vis[mxn];

// isc[v]: 1 while v belongs to the component currently handled by centroid_decomp
// (set by dfs, cleared again before centroid_decomp recurses).
int isc[mxn];

// cvis[v]: 1 once v has been enqueued by the BFS of the current centroid_decomp call.
int cvis[mxn];

// par[v]: parent of v in the dfs rooted at the current centroid (-1 for the centroid itself).
int par[mxn];

// Smallest number of colours seen in any successful BFS so far; main prints ans - 1.
int ans = 1e9;

// adj[v]: neighbours of v in the input tree.
vi adj[mxn];

// visc[col]: 1 once colour col has been expanded by the current BFS (sized to k in main).
vi visc;

// Nodes of the component currently being processed, filled by dfs.
vi cc;

// colnode[col]: every node whose colour is col.
vector<vi> colnode;

// Fills sz[] for `node` and everything below it, where "below" means reachable
// without stepping back to `from` and without entering a removed (vis) node.
void get_sz(int node, int from){

    int &here = sz[node];

    here = 1;

    for(int nxt : adj[node]){

        // Skip the edge we arrived by and any already-removed centroid.
        if(nxt == from || vis[nxt]) continue;

        get_sz(nxt, node);

        here += sz[nxt];
    }
}

// Walks from `node` towards the first live child whose subtree size exceeds
// `tar`, and returns the node at which no such child exists any more.
// Relies on sz[] having been filled by get_sz for this component.
int get_centroid(int node, int from, int tar){

    for(;;){

        int heavy = -1;

        for(int nxt : adj[node]){

            if(nxt == from || vis[nxt] || sz[nxt] <= tar) continue;

            // The first qualifying child is taken, exactly as the recursive form did.
            heavy = nxt;

            break;
        }

        if(heavy == -1) return node;

        from = node;

        node = heavy;
    }
}

// Traverses the live component from `node`, recording each visited node in cc,
// flagging it in isc, and storing the node it was reached from in par.
void dfs(int node, int from){

    isc[node] = 1;

    par[node] = from;

    cc.pb(node);

    for(int nxt : adj[node]){

        // Removed centroids bound the component; `from` is the way back up.
        if(nxt == from || vis[nxt]) continue;

        dfs(nxt, node);
    }
}

void centroid_decomp(int node){

    get_sz(node, -1);

    int ce =
    get_centroid(node, -1, sz[node] / 2);

    node = ce;

    cc.clear();

    dfs(node, -1);

    queue<int> q;

    q.push(node);

    cvis[node] = 1;

    int ct = 0;

    vi cols;

    bool ok = 1;

    while(!q.empty()){

        int cur = q.front();

        q.pop();

        if(!isc[cur]){

            ok = 0;

            break;
        }

        if(!visc[c[cur]]){

            visc[c[cur]] = 1;

            ct++;

            cols.pb(c[cur]);

            for(auto u : colnode[c[cur]]){

                if(!cvis[u]){

                    if(isc[u]){

                        cvis[u] = 1;
                    }
                    else{

                        ok = 0;

                        break;
                    }

                    q.push(u);
                }
            }
        }

        if(!ok) break;

        if(cur != ce){

            cur = par[cur];

            while(!cvis[cur]){

                if(isc[cur]){

                    cvis[cur] = 1;
                }
                else{

                    ok = 0;

                    break;
                }

                q.push(cur);

                cur = par[cur];
            }
        }

        if(!ok) break;
    }

    if(ok){

        ans = min(ans, ct);
    }

    for(auto u : cc){

        isc[u] = 0;

        cvis[u] = 0;
    }

    for(auto u : cols){

        visc[u] = 0;
    }

    vis[ce] = 1;

    for(auto u : adj[ce]){

        if(!vis[u]){

            centroid_decomp(u);
        }
    }
}

// Reads the tree (n nodes, n - 1 edges given 1-based) and the 1-based colour of
// every node, builds the per-colour node lists, runs the decomposition from
// node 0 and prints ans - 1.
signed main(){

    ios::sync_with_stdio(false);

    cin.tie(nullptr);

    cin >> n >> k;

    for(int i = 0; i < n - 1; i++){

        int a, b;

        cin >> a >> b;

        // Convert to 0-based node ids.
        a--;
        b--;

        adj[a].pb(b);

        adj[b].pb(a);
    }

    for(int i = 0; i < n; i++){

        cin >> c[i];

        // Convert to 0-based colour ids.
        c[i]--;
    }

    // One bucket per colour. The earlier version also pushed n empty buckets
    // while reading, which this assign discarded immediately.
    colnode.assign(k, {});

    for(int i = 0; i < n; i++){

        colnode[c[i]].pb(i);
    }

    visc.resize(k);

    centroid_decomp(0);

    cout << ans - 1 << '\n';
}
// # | Result | Execution time | Memory | Grader output
// Fetching results...
// # | Result | Execution time | Memory | Grader output
// Fetching results...
// # | Result | Execution time | Memory | Grader output
// Fetching results...
// # | Result | Execution time | Memory | Grader output
// Fetching results...