Submission #1319578

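| # | Time | Username | Problem | Language | Result | Execution time | Memory |
|---|------|----------|---------|----------|--------|----------------|--------|
| 1319578 | | Wansur | 수도 (JOI20_capital_city) | C++20 | 11 / 100 | 1384 ms | 459960 KiB |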
#include <bits/stdc++.h>
#define ent '\n'
// #define int long long

using namespace std;

// Capacity for towns / colors (n, k), with a little slack.
constexpr int maxn = 200'020;
// Table depth for up[][] / go[][] and the node-pool multiplier: the color
// graph e[] / rev[] has room for maxn * lk nodes.
constexpr int lk = 20;

// g: tree adjacency lists; e: color-graph out-edges; rev: e with every edge reversed.
vector<int> g[maxn], e[maxn * lk], rev[maxn * lk];
// up[i][v]: 2^i-th ancestor of v (0 above the root).
// go[i][v]: color-graph node standing for the 2^i towns from v upwards.
int up[lk][maxn], go[lk][maxn];
// c: town colors; tin/tout: Euler-tour times; lca: per-color LCA of its towns.
int c[maxn], tin[maxn], tout[maxn], lca[maxn];
// n towns, k colors, N allocated graph nodes, timer for the Euler tour.
int n, k, N, timer;

// Roots the tree at v (whose parent is p; 0 for the root) and fills, for every
// reached vertex x:
//   tin[x] / tout[x] - Euler-tour entry time of x and the largest entry time
//                      inside x's subtree;
//   up[i][x]         - the 2^i-th ancestor of x (0 once past the root);
//   go[i][x]         - a color-graph node whose out-edges lead (through lower
//                      levels) to the colors of x and its next 2^i - 1 ancestors.
// Only levels 0..17 are built; get_lca and solve also walk levels 17..0.
//
// An explicit stack replaces recursion. A recursive walk nests once per tree
// level, and a path-shaped tree with ~2e5 vertices can overflow the call stack.
// The pre-order / post-order side effects and node numbering are unchanged.
void dfs(int v, int p) {
    // Pre-order work for vertex x entered from parent par.
    auto enter = [](int x, int par) {
        tin[x] = ++timer;

        // Each vertex allocates at most 18 new nodes. Fail fast (instead of
        // spinning forever) if the node pool backing e[] / rev[] would overflow.
        if(N + 30 >= maxn * lk) {
            cerr << "dfs: color-graph node pool exhausted\n";
            abort();
        }

        up[0][x] = par;
        go[0][x] = ++N;
        e[go[0][x]].push_back(c[x]);  // level 0 reaches x's own color

        for(int i = 1; i < 18; i++) {
            up[i][x] = up[i - 1][up[i - 1][x]];
            if(up[i - 1][x] != 0) {
                // A 2^i block is the union of two consecutive 2^(i-1) blocks.
                go[i][x] = ++N;
                e[go[i][x]].push_back(go[i - 1][x]);
                e[go[i][x]].push_back(go[i - 1][up[i - 1][x]]);
            }
        }
    };

    // Stack entries: (vertex, parent, index of the next neighbour to examine).
    vector<array<int, 3>> st;
    enter(v, p);
    st.push_back({v, p, 0});
    while(!st.empty()) {
        const int x = st.back()[0];
        const int par = st.back()[1];
        const int idx = st.back()[2]++;
        if(idx < static_cast<int>(g[x].size())) {
            const int to = g[x][idx];
            if(to != par) {
                enter(to, x);
                st.push_back({to, x, 0});
            }
        } else {
            tout[x] = timer;  // every vertex of x's subtree has been entered
            st.pop_back();
        }
    }
}

// True when u is an ancestor of v or u == v: v's Euler-tour interval
// [tin[v], tout[v]] lies inside u's interval [tin[u], tout[u]].
bool check(int u, int v) {
    const bool entered_after = tin[u] <= tin[v];
    if(!entered_after) {
        return false;
    }
    const bool closed_before = tout[v] <= tout[u];
    return closed_before;
}

// Lowest common ancestor of u and v via the binary-lifting table up[][].
// If either vertex is an ancestor of the other, that vertex is returned.
// Otherwise v is lifted (levels 17..0) to the highest ancestor that is still
// not an ancestor of u, and that vertex's parent is the answer.
int get_lca(int u, int v) {
    if(check(u, v)) {
        return u;
    }
    if(check(v, u)) {
        return v;
    }

    int x = v;
    for(int lvl = 17; lvl >= 0; --lvl) {
        const int anc = up[lvl][x];
        const bool still_below = anc != 0 && !check(anc, u);
        if(still_below) {
            x = anc;
        }
    }
    return up[0][x];
}

vector<int> ord;
int comp[maxn], used[maxn], fg[maxn], sz[maxn];

// First Kosaraju pass: depth-first search over e[] from v that marks used[]
// and appends every newly finished node to ord in post-order.
// Iterative with an explicit (node, next edge index) stack, because the color
// graph can hold millions of nodes and recursion that deep can overflow the
// call stack. The visiting order and the resulting post-order are identical to
// a recursive DFS.
void dfs(int v) {
    static vector<pair<int, int>> st;  // reused buffer; always empty between calls
    used[v] = true;
    st.emplace_back(v, 0);
    while(!st.empty()) {
        const int x = st.back().first;
        const int idx = st.back().second++;
        if(idx < static_cast<int>(e[x].size())) {
            const int to = e[x][idx];
            if(!used[to]) {
                used[to] = true;  // marked on entry, as the recursive version did
                st.emplace_back(to, 0);
            }
        } else {
            ord.push_back(x);  // all out-edges of x explored
            st.pop_back();
        }
    }
}

// Second Kosaraju pass: labels v and every node reachable from it in rev[]
// whose comp[] is still 0 with component id col.
// Iterative flood fill with an explicit stack, to avoid deep recursion on a
// graph with millions of nodes. The set of nodes labelled is the same as with
// the recursive version; only the visiting order differs.
void calc(int v, int col) {
    static vector<int> st;  // reused buffer; always empty between calls
    comp[v] = col;
    st.push_back(v);
    while(!st.empty()) {
        const int x = st.back();
        st.pop_back();
        for(int to : rev[x]) {
            if(!comp[to]) {
                comp[to] = col;
                st.push_back(to);
            }
        }
    }
}

// Reads the tree and town colors, builds the color graph, finds its SCCs with
// Kosaraju's algorithm, and prints (smallest color count of an SCC with no
// outgoing edge) - 1.
//
// Graph layout: nodes 1..k are colors; nodes k+1..N are jump nodes created by
// dfs(int, int), where go[i][x] leads to the colors of x and its next 2^i - 1
// ancestors. Each color gets edges to jump nodes covering the tree paths from
// its towns up to their common LCA. So color a reaches every color that appears
// on the subtree spanning a's towns.
//
// NOTE(review): used[], comp[], fg[] and sz[] are indexed below by node / SCC
// ids up to N, so they must be sized for the whole node pool.
void solve() {
    cin >> n >> k;
    N = k;  // jump nodes are numbered after the k color nodes

    for(int i = 1; i < n; i++) {
        int u, v;
        cin >> u >> v;
        g[u].push_back(v);
        g[v].push_back(u);
    }

    for(int i = 1; i <= n; i++) {
        cin >> c[i];
    }

    dfs(1, 0);

    // lca[col]: LCA of all towns painted col, folded in one town at a time.
    for(int i = 1; i <= n; i++) {
        if(!lca[c[i]]) {
            lca[c[i]] = i;
        }

        lca[c[i]] = get_lca(lca[c[i]], i);
    }

    // For every town v, connect color c[v] to jump nodes that together cover
    // the tree path from v up to lca[c[v]] (inclusive).
    for(int v = 1; v <= n; v++) {
        int cur = v;
        // Greedily take the block of 2^i towns starting at cur while the next
        // vertex after it, up[i][cur], is still strictly below lca[c[v]].
        for(int i = 17; i >= 0; i--) {
            if(up[i][cur] != 0 && !check(up[i][cur], lca[c[v]])) {
                e[c[v]].push_back(go[i][cur]);
                cur = up[i][cur];
            }
        }

        // The greedy loop leaves cur either at the LCA or exactly one step
        // below it. go[1][cur] covers cur and its parent; go[0][cur] covers
        // cur alone.
        if(cur != lca[c[v]]) {
            e[c[v]].push_back(go[1][cur]);
        }
        else {
            e[c[v]].push_back(go[0][cur]);
        }
    }

    // Build the reversed graph for the second Kosaraju pass.
    for(int v = 1; v <= N; v++) {
        for(int to : e[v]) {
            rev[to].push_back(v);
        }
    }

    // Kosaraju pass 1: post-order of every node along e[].
    for(int v = 1; v <= N; v++) {
        if(!used[v]) {
            dfs(v);
        }
    }

    reverse(ord.begin(), ord.end());

    // Kosaraju pass 2: flood rev[] in decreasing finish order; each flood is one SCC.
    int cnt = 0;
    for(int v : ord) {
        if(!comp[v]) {
            calc(v, ++cnt);
        }
    }

    // sz[id]: how many color nodes (1..k) landed in SCC id.
    for(int v = 1; v <= k; v++) {
        sz[comp[v]]++;
    }

    int ans = 1e9;
    // fg[id] = true when SCC id has an edge into a different SCC.
    for(int v = 1; v <= N; v++) {
        for(int to : e[v]) {
            if(comp[v] != comp[to]) {
                fg[comp[v]] = true;
            }
        }
    }

    // Among SCCs with no outgoing edge that contain at least one color,
    // take the smallest color count.
    for(int i = 1; i <= cnt; i++) {
        if(!fg[i] && sz[i] > 0) {
            ans = min(ans, sz[i]);
        }
    }

    // Print that color count minus one.
    cout << ans - 1 << ent;
}

// Program entry: unties cin from cout, turns off C-stdio synchronisation for
// faster stream I/O, then runs solve() once (multi-test input is disabled).
int32_t main() {
    cin.tie(nullptr);
    ios_base::sync_with_stdio(false);
    int tests = 1;
    // cin >> tests;
    for(; tests > 0; --tests) {
        solve();
    }
}
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...