제출 #1257447

# | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리
1257447 | | MisterReaper | 수도 (JOI20_capital_city) | C++20 | 41 / 100 | 3096 ms | 31172 KiB
// File AA.cpp created on 13.08.2025 at 11:06:04
#include <bits/stdc++.h>

// Shorthand for the built-in `long long` integer type.
using i64 = long long int;

// Building with -DDEBUG pulls in a machine-local debug header (presumably
// defining debug(...)); otherwise debug(...) expands to a no-op void expression.
#if defined(DEBUG)
    #include "/home/ahmetalp/Desktop/Workplace/debug.h"
#else
    #define debug(...) static_cast<void>(0)
#endif

// Solves JOI 2020 Spring Camp "Capital City" (JOI20_capital_city): returns the
// size of the smallest valid color set minus one (the number of merges).
//
// For every centroid of a centroid decomposition, a color set is grown from the
// centroid's color by repeatedly adding the color of each vertex on the path from
// an already-taken vertex up to the centroid. A candidate is abandoned (K + 1) as
// soon as it needs a color that has a vertex outside the current component.
//
// K   : number of colors; every C[i] is assumed to be in [0, K).
// adj : adjacency lists of the tree (0-based vertices).
// C   : color of each vertex (0-based).
static int min_capital_merges(int K, const std::vector<std::vector<int>>& adj,
                              const std::vector<int>& C) {
    const int N = static_cast<int>(adj.size());

    // cols[c] = every vertex of color c.
    std::vector<std::vector<int>> cols(K);
    for (int i = 0; i < N; ++i) {
        cols[C[i]].push_back(i);
    }

    std::vector<int> colscnt(K, 0);  // vertices of each color inside the current component
    std::vector<int> siz(N), par(N);
    std::vector<char> act(N, 1);     // cleared once a vertex has served as a centroid
    std::vector<char> vis(N, 0);     // set once the vertex's color joined the candidate set
    std::vector<int> vtx, dfs_stack, stk;

    // Fills vtx with the active component containing `root` (each vertex appears
    // after its parent) and sets par[] for the tree rooted at `root`
    // (par[root] == root). Iterative so path-shaped trees cannot overflow the stack.
    auto collect = [&](int root) {
        vtx.clear();
        par[root] = root;
        dfs_stack.assign(1, root);
        while (!dfs_stack.empty()) {
            const int v = dfs_stack.back();
            dfs_stack.pop_back();
            vtx.push_back(v);
            for (int u : adj[v]) {
                if (u == par[v] || !act[u]) {
                    continue;
                }
                par[u] = v;
                dfs_stack.push_back(u);
            }
        }
    };

    // Returns the centroid of the active component containing `root`.
    auto find_centroid = [&](int root) -> int {
        collect(root);
        for (int v : vtx) {
            siz[v] = 1;
        }
        // Children follow their parent in vtx, so a reverse sweep finishes every
        // subtree before adding it to its parent.
        for (auto it = vtx.rbegin(); it != vtx.rend(); ++it) {
            if (*it != root) {
                siz[par[*it]] += siz[*it];
            }
        }
        const int total = siz[root];
        int v = root;
        bool moved = true;
        while (moved) {
            moved = false;
            for (int u : adj[v]) {
                // The neighbour u must be active: removed centroids separate components.
                if (u != par[v] && act[u] && siz[u] * 2 > total) {
                    v = u;
                    moved = true;
                    break;
                }
            }
        }
        return v;
    };

    // True when every vertex of `color` lies inside the current component.
    auto fully_inside = [&](int color) {
        return colscnt[color] == static_cast<int>(cols[color].size());
    };

    // Adds every vertex of `color` to the candidate set and queues it for climbing.
    auto take = [&](int color) {
        for (int w : cols[color]) {
            assert(!vis[w]);
            vis[w] = 1;
            stk.push_back(w);
        }
    };

    // Size of the color set grown from C[centre] inside the current component
    // (par[] must be rooted at centre), or K + 1 if it would have to leave it.
    auto evaluate = [&](int centre) -> int {
        if (!fully_inside(C[centre])) {
            return K + 1;
        }
        stk.clear();
        take(C[centre]);
        int cnt = 1;
        while (!stk.empty()) {
            int u = par[stk.back()];
            stk.pop_back();
            // Climb toward the centre; vis[centre] is set, so every climb terminates.
            while (!vis[u]) {
                if (!fully_inside(C[u])) {
                    return K + 1;
                }
                take(C[u]);
                ++cnt;
                u = par[u];
            }
        }
        return cnt;
    };

    int best = K + 1;

    // Centroid decomposition; each child component is at most half the size of
    // its parent, so the recursion depth is O(log N).
    auto decompose = [&](auto&& self, int root) -> void {
        const int centre = find_centroid(root);
        collect(centre);  // re-root so par[] points toward the centroid
        for (int u : vtx) {
            vis[u] = 0;
            ++colscnt[C[u]];
        }

        best = std::min(best, evaluate(centre));

        for (int u : vtx) {
            --colscnt[C[u]];
        }
#ifdef LOCAL
        assert(std::count(colscnt.begin(), colscnt.end(), 0) == K);
#endif

        act[centre] = 0;
        for (int u : adj[centre]) {
            if (act[u]) {
                self(self, u);
            }
        }
    };
    decompose(decompose, 0);

    return best - 1;
}

// Reads N, K, the N - 1 tree edges and the N colors (all 1-based), then prints
// the minimum number of merges.
int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int N, K;
    std::cin >> N >> K;

    std::vector<std::vector<int>> adj(N);
    for (int i = 1; i < N; ++i) {
        int A, B;
        std::cin >> A >> B;
        --A, --B;
        adj[A].emplace_back(B);
        adj[B].emplace_back(A);
    }

    std::vector<int> C(N);
    for (int i = 0; i < N; ++i) {
        std::cin >> C[i];
        --C[i];
    }

    std::cout << min_capital_merges(K, adj, C) << '\n';

    return 0;
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...