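// 제출 #883120 (oj.uz submission #883120)
//
// # | 제출 시각 (submitted at) | 아이디 (user) | 문제 (problem) | 언어 (language) | 결과 (result) | 실행 시간 (time) | 메모리 (memory)
// 883120 | (not recorded) | vjudge1 | 수도 (JOI20_capital_city) | C++17 | 100 / 100 | 1588 ms | 401144 KiB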
// https://oj.uz/problem/view/JOI20_capital_city

#include <bits/stdc++.h>
using namespace std;

// Y-combinator adaptor for writing recursive lambdas.
//
// Deliberately defined in the global namespace: adding declarations or
// definitions to namespace std is undefined behaviour ([namespace.std]).
// Callers use the unqualified name `y_combinator`, which still resolves.
//
// Example
//         auto fun = y_combinator([&](auto self, int x) -> void {
//                 self(x + 1);
//         });
//
// The wrapped lambda receives `self` (a std::reference_wrapper to the
// wrapper) as its first argument. It should declare its return type
// explicitly so the recursive call does not need return-type deduction of
// itself.
template <class Fun>
class y_combinator_result {
        Fun fun_;

       public:
        template <class T>
        explicit y_combinator_result(T &&fun) : fun_(std::forward<T>(fun)) {}

        // Calls the wrapped callable, passing a reference to this wrapper as
        // the first ("self") argument so the callable can recurse through it.
        template <class... Args>
        decltype(auto) operator()(Args &&...args) {
                return fun_(std::ref(*this), std::forward<Args>(args)...);
        }
};

// Wraps `fun` (decayed, then copied or moved in) into a y_combinator_result.
template <class Fun>
decltype(auto) y_combinator(Fun &&fun) {
        return y_combinator_result<std::decay_t<Fun>>(std::forward<Fun>(fun));
}

// JOI 2020 "Capital City".
//
// Builds an implication graph whose nodes are the k cities plus virtual
// "path segment" nodes (binary lifting). An edge a -> b means "if a is kept,
// b must be merged into it". Each city points at every city on the tree paths
// between its own towns. The answer is the minimum, over strongly connected
// components with no outgoing edge, of (#cities in the component - 1).
int32_t main() {
        ios_base::sync_with_stdio(0);
        cin.tie(0);

        // n towns forming a tree (n - 1 edges), k cities; c[i] = city of town i.
        int n, k;
        cin >> n >> k;
        vector<vector<int>> adj(n);
        for (int i = 0; i < n - 1; i++) {
                int u, v;
                cin >> u >> v;
                u--, v--;
                adj[u].emplace_back(v);
                adj[v].emplace_back(u);
        }
        vector<int> c(n);
        for (int i = 0; i < n; i++) cin >> c[i], c[i]--;

        const int LOG = __lg(n);
        // d[u]      : depth of town u (town 0 is the root, depth 0).
        // up[i][u]  : 2^i-th ancestor of u, clamped at the root.
        // rmq[i][u] : graph node meaning "every city on the 2^i towns from u
        //             upwards (u included, up[i][u] excluded)". Level 0 is the
        //             city node c[u]. Higher levels are virtual nodes and are
        //             only built when those 2^i towns exist.
        // g         : implication graph; nodes [0, k) are cities.
        vector<int> d(n);
        vector<vector<int>> up(LOG + 1, vector<int>(n));
        vector<vector<int>> rmq(LOG + 1, vector<int>(n));
        vector<vector<int>> g(k);

        // Fills d / up / rmq and creates the virtual segment nodes.
        // Relies on d[u] being set by the parent before recursing.
        auto dfs = y_combinator([&](auto self, int u, int p) -> void {
                up[0][u] = p;
                rmq[0][u] = c[u];
                for (int i = 1; i <= LOG; i++) {
                        up[i][u] = up[i - 1][up[i - 1][u]];
                        // Only segments of exactly 2^i towns are ever queried. A segment
                        // running past the root would read the unset rmq[i - 1][root]
                        // (value 0) and add a bogus edge to city 0.
                        if (d[u] + 1 < (1 << i)) continue;
                        rmq[i][u] = static_cast<int>(g.size());
                        g.push_back({rmq[i - 1][u], rmq[i - 1][up[i - 1][u]]});
                }
                for (int v : adj[u]) {
                        if (v == p) continue;
                        d[v] = d[u] + 1;
                        self(v, u);
                }
        });

        // Adds edges from city `group_id` to (segment nodes covering) every city
        // on the tree path x..y, both ends included, and returns lca(x, y).
        auto lca = [&](int group_id, int x, int y) {
                if (d[x] < d[y]) swap(x, y);
                // Lift x by exactly d[x] - d[y]. The test `d[up[i][x]] >= d[y]` would
                // also accept root-clamped jumps whenever y is the root, pulling in
                // segments that do not exist.
                for (int i = LOG; i >= 0; i--) {
                        if ((d[x] - d[y]) >> i & 1) {
                                g[group_id].emplace_back(rmq[i][x]);
                                x = up[i][x];
                        }
                }
                if (x == y) {
                        g[group_id].emplace_back(rmq[0][x]);
                        return x;
                }
                // Standard simultaneous lift; every jump taken stays strictly below the LCA.
                for (int i = LOG; i >= 0; i--) {
                        if (up[i][x] != up[i][y]) {
                                g[group_id].emplace_back(rmq[i][x]);
                                g[group_id].emplace_back(rmq[i][y]);
                                x = up[i][x], y = up[i][y];
                        }
                }
                // x and y are now children of the LCA. rmq[1][y] covers y and the LCA itself.
                g[group_id].emplace_back(rmq[0][x]);
                g[group_id].emplace_back(rmq[1][y]);
                return up[0][x];
        };

        vector<vector<int>> group(k);
        for (int i = 0; i < n; i++) group[c[i]].emplace_back(i);

        dfs(0, 0);

        // Connect every city to all cities on the paths between its towns by
        // folding its towns through lca(), accumulating the running top.
        for (int i = 0; i < k; i++) {
                if (group[i].empty()) continue;  // no towns: nothing to connect
                int top = group[i][0];
                for (int j : group[i]) top = lca(i, top, j);
        }

        // Tarjan's SCC over g. comp[v] = component id of node v. cnt[id] = number
        // of city nodes (ids < k) in component id. Finished nodes get
        // num = low = N + 1, so later edges into them never lower low[].
        // NOTE(review): both DFSs are recursive. Depth can reach n (tree) and N
        // (graph), so this relies on a large stack.
        const int N = static_cast<int>(g.size());
        vector<int> low(N), num(N);
        int timer = 0;
        stack<int> st;
        int M = 0;
        vector<int> comp(N);
        vector<int> cnt;

        auto dfs2 = y_combinator([&](auto self, int u) -> void {
                low[u] = num[u] = ++timer;
                st.emplace(u);
                for (int v : g[u]) {
                        if (num[v]) {
                                low[u] = min(low[u], num[v]);
                        } else {
                                self(v);
                                low[u] = min(low[u], low[v]);
                        }
                }
                if (low[u] == num[u]) {
                        int v = -1;
                        cnt.emplace_back(0);
                        do {
                                v = st.top();
                                comp[v] = M;
                                st.pop();
                                num[v] = low[v] = N + 1;
                                cnt[M] += v < k;
                        } while (v != u);
                        M++;
                }
        });

        for (int i = 0; i < N; i++) {
                if (!num[i]) dfs2(i);
        }

        // A component with an edge leaving it is not closed under "needs",
        // so it cannot be the final merged capital.
        vector<char> has_out(M, 0);
        for (int i = 0; i < N; i++) {
                for (int j : g[i]) {
                        if (comp[i] != comp[j]) has_out[comp[i]] = 1;
                }
        }

        int res = k;
        for (int i = 0; i < k; i++) {
                if (group[i].empty() || has_out[comp[i]]) continue;
                res = min(res, cnt[comp[i]] - 1);
        }

        cout << res << '\n';
}
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...