제출 #1257318

# | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리
1257318 | | MisterReaper | 수도 (JOI20_capital_city) | C++20 | 11 / 100 | 3101 ms | 334816 KiB
// File AA.cpp created on 13.08.2025 at 00:09:57
#include <bits/stdc++.h>

// Shorthand for a 64-bit-capable signed integer type.
// NOTE(review): not referenced anywhere in the visible code.
using i64 = long long;

// Debug tracing switch.
// With DEBUG defined, a header from the author's local machine is included
// (presumably it provides `debug(...)`; the header is not part of this file).
// Otherwise every `debug(...)` call expands to the no-op expression `void(23)`,
// so its arguments are never evaluated.
#ifdef DEBUG 
    #include "/home/ahmetalp/Desktop/Workplace/debug.h"
#else
    #define debug(...) void(23)
#endif

/**
 * Program entry point.
 *
 * Input (stdin):
 *   - N K
 *   - N-1 lines "A B": undirected edges of a tree on vertices 1..N
 *   - N values C[i] in 1..K: the color of each vertex
 *
 * Output (stdout): one integer, `ans - 1`, where `ans` is the smallest number
 * of color nodes found in a strongly connected component of the dependency
 * graph that no other color component can reach along the stored (reversed)
 * edges. If no component qualifies, `ans` stays K + 1 and K is printed.
 * (Presumably this is the minimum number of color merges asked for by the
 * JOI 2020 "Capital City" task - confirm against the problem statement.)
 *
 * Outline:
 *   1. Root the tree at vertex 0; build depths and a binary-lifting table.
 *   2. Create a graph whose nodes are the K colors plus one helper node
 *      lift[l][v] per (level, vertex). lift[l][v] stands for the 2^l vertices
 *      on the path starting at v and going upwards. Only REVERSED edges are
 *      stored in `ibdj`.
 *   3. For every color, connect it to the helper nodes that cover the paths
 *      from each of its vertices up to the LCA of all its vertices.
 *   4. Compute strongly connected components (Tarjan) on `ibdj`.
 *   5. Propagate component labels from the color nodes along `ibdj`; any node
 *      reached from two different components is marked bad (-2).
 *   6. Answer = min color count over components with no bad node, minus 1.
 *
 * @return 0 on normal termination.
 */
int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int N, K;
    std::cin >> N >> K;

    // Undirected tree; vertices are converted from 1-based to 0-based.
    std::vector<std::vector<int>> adj(N);
    for (int i = 1; i < N; ++i) {
        int A, B;
        std::cin >> A >> B;
        --A, --B;
        adj[A].emplace_back(B);
        adj[B].emplace_back(A);
    }

    // cols[c] lists the vertices of color c; C[v] is the (0-based) color of v.
    std::vector<std::vector<int>> cols(K);
    std::vector<int> C(N);
    for (int i = 0; i < N; ++i) {
        std::cin >> C[i];
        --C[i];
        cols[C[i]].emplace_back(i);
    }

    // Number of binary-lifting levels = bit width of N (floor(log2 N) + 1).
    // Computed with a plain loop instead of the libstdc++-internal std::__lg.
    int level_count = 1;
    while ((N >> level_count) > 0) {
        ++level_count;
    }
    const int LG = level_count;

    // Root the tree at vertex 0. The root is its own parent (par[0][0] == 0
    // from value-initialisation), so lifting past the root stays at the root.
    // The traversal is iterative: a path-shaped tree would otherwise need a
    // recursion depth of N.
    std::vector<int> dep(N);
    std::vector<std::vector<int>> par(LG, std::vector<int>(N));
    {
        std::vector<int> order;
        order.reserve(N);
        order.push_back(0);
        for (std::size_t head = 0; head < order.size(); ++head) {
            const int v = order[head];
            for (const int u : adj[v]) {
                if (u == par[0][v]) {
                    continue;
                }
                par[0][u] = v;
                dep[u] = dep[v] + 1;
                order.push_back(u);
            }
        }
    }

    // par[l][v] = 2^l-th ancestor of v (clamped at the root).
    for (int l = 0; l < LG - 1; ++l) {
        for (int v = 0; v < N; ++v) {
            par[l + 1][v] = par[l][par[l][v]];
        }
    }

    // Lowest common ancestor of u and v by binary lifting.
    auto lca = [&](int u, int v) -> int {
        if (dep[u] < dep[v]) {
            std::swap(u, v);
        }
        // Lift the deeper vertex to the depth of the other one.
        int d = dep[u] - dep[v];
        for (int l = LG - 1; l >= 0; --l) {
            if (d >> l & 1) {
                u = par[l][u];
            }
        }
        if (u == v) {
            return u;
        }
        // Lift both while their ancestors differ; the LCA is one step above.
        for (int l = LG - 1; l >= 0; --l) {
            if (par[l][u] != par[l][v]) {
                u = par[l][u];
                v = par[l][v];
            }
        }
        return par[0][v];
    };

    // Node ids: 0..K-1 are colors, the rest are helper nodes lift[l][v].
    int node_cnt = K;
    const int MAX_NODE = N * LG + K;

    // ibdj holds the REVERSE of the dependency graph: ibdj[x] contains y when
    // the forward graph has the edge y -> x.
    std::vector<std::vector<int>> ibdj(MAX_NODE);
    std::vector<std::vector<int>> lift(LG, std::vector<int>(N));
    for (int v = 0; v < N; ++v) {
        lift[0][v] = node_cnt++;
        // Forward edge: lift[0][v] -> C[v].
        ibdj[C[v]].emplace_back(lift[0][v]);
    }

    debug(__LINE__);

    for (int l = 0; l < LG - 1; ++l) {
        for (int v = 0; v < N; ++v) {
            lift[l + 1][v] = node_cnt++;
            // Forward edges: lift[l+1][v] -> lift[l][v] and
            //                lift[l+1][v] -> lift[l][par[l][v]].
            ibdj[lift[l][v]].emplace_back(lift[l + 1][v]);
            ibdj[lift[l][par[l][v]]].emplace_back(lift[l + 1][v]);
        }
    }

    debug(__LINE__);

    // Connects color c (forward: c -> helper) to the helper nodes that cover
    // the dep[a] - dep[b] + 1 vertices on the path from a up to its ancestor b,
    // one helper per set bit of that length.
    auto add_edge = [&](int c, int a, int b) -> void {
        int d = dep[a] - dep[b] + 1;
        for (int l = LG - 1; l >= 0; --l) {
            if (d >> l & 1) {
                ibdj[lift[l][a]].emplace_back(c);
                a = par[l][a];
            }
        }
    };

    for (int c = 0; c < K; ++c) {
        // top = LCA of every vertex of color c (-1 when the color is unused).
        int top = -1;
        for (auto v : cols[c]) {
            if (top == -1) {
                top = v;
            } else {
                top = lca(top, v);
            }
        }
        for (auto v : cols[c]) {
            add_edge(c, v, top);
        }
    }

    debug(__LINE__);

    // Strongly connected components of ibdj (Tarjan).
    //   bel[v] = component id of node v, siz[id] = number of color nodes in it,
    //   n      = number of components.
    // An explicit stack replaces recursion because a DFS path may visit a
    // large share of the N * LG + K nodes.
    int n = 0;
    std::vector<int> bel(node_cnt, -1), siz(node_cnt);
    {
        int tim = 0;
        std::vector<int> tin(node_cnt, -1), low(node_cnt, -1), stk;
        // Each frame is (node, index of the next outgoing edge to examine).
        std::vector<std::pair<int, int>> frames;

        auto enter = [&](int v) -> void {
            tin[v] = low[v] = tim++;
            stk.emplace_back(v);
            frames.emplace_back(v, 0);
        };

        for (int root = 0; root < node_cnt; ++root) {
            if (tin[root] != -1) {
                continue;
            }
            enter(root);
            while (!frames.empty()) {
                const int v = frames.back().first;
                const int idx = frames.back().second;
                if (idx < static_cast<int>(ibdj[v].size())) {
                    // Advance the edge cursor before a possible push, which
                    // may reallocate `frames`.
                    frames.back().second = idx + 1;
                    const int u = ibdj[v][idx];
                    if (tin[u] == -1) {
                        enter(u);
                    } else if (bel[u] == -1) {
                        low[v] = std::min(low[v], tin[u]);
                    }
                    continue;
                }

                // All edges of v are done: close its component if v is a root.
                frames.pop_back();
                if (tin[v] == low[v]) {
                    int u;
                    do {
                        u = stk.back();
                        stk.pop_back();
                        bel[u] = n;
                        siz[n] += (u < K);
                    } while (u != v);
                    n++;
                }
                // Equivalent of "low[parent] = min(low[parent], low[v])" after
                // the recursive call returns.
                if (!frames.empty()) {
                    const int p = frames.back().first;
                    low[p] = std::min(low[p], low[v]);
                }
            }
        }
        // tin / low / stacks are destroyed here, which really frees their
        // memory (vector::clear() alone would keep the capacity).
    }

    // The lifting table is no longer needed; swap with an empty vector so the
    // storage is actually released.
    std::vector<std::vector<int>>().swap(par);

    debug(__LINE__);

    // Label propagation along ibdj.
    //   vis[v] == -1 : not reached yet
    //   vis[v] == id : every color node that reached v so far is in component id
    //   vis[v] == -2 : reached from at least two different components (bad)
    // inque[v] counts how often v was queued; it exists only to back the
    // sanity check that no node is queued more than twice.
    std::vector<int> inque(node_cnt), vis(node_cnt, -1);
    std::queue<int> que;

    // The counter update is kept outside assert() so that it is not compiled
    // away under NDEBUG.
    auto enqueue = [&](int v) -> void {
        ++inque[v];
        assert(inque[v] <= 2);
        que.emplace(v);
    };

    auto add = [&](int v, int c) -> void {
        if (vis[v] == -1) {
            vis[v] = c;
            enqueue(v);
        } else if (vis[v] != c && vis[v] != -2) {
            vis[v] = -2;
            enqueue(v);
        }
    };

    // Seed: every color that occurs starts with its own component as label.
    for (int i = 0; i < N; ++i) {
        add(C[i], bel[C[i]]);
    }

    while (!que.empty()) {
        auto v = que.front();
        que.pop();
        for (auto u : ibdj[v]) {
            if (vis[v] == -2) {
                // A bad node makes everything it reaches bad as well.
                if (vis[u] != -2) {
                    vis[u] = -2;
                    enqueue(u);
                }
            } else {
                add(u, vis[v]);
            }
        }
    }

    // A component is unusable as soon as one of its nodes is marked bad.
    std::vector<int> bad_cop(n);
    for (int i = 0; i < node_cnt; ++i) {
        if (vis[i] == -2) {
            bad_cop[bel[i]] = true;
        }
    }

    // Smallest color count among usable components that contain a color.
    int ans = K + 1;
    for (int i = 0; i < n; ++i) {
        if (!bad_cop[i] && siz[i]) {
            ans = std::min(ans, siz[i]);
        }
    }

    std::cout << ans - 1 << '\n';

    return 0;
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...