제출 #1276923

# | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리
1276923 | duckindog | 수도 (JOI20_capital_city) | C++20
100 / 100
383 ms | 34040 KiB
#include <bits/stdc++.h>

using namespace std;

// Array capacity: the largest vertex count (2e5) plus a little slack.
constexpr int N = 200'000 + 10;
// Sentinel larger than any achievable answer; `answer` starts from it.
constexpr int MAX = 1'000'000;

int n;              // number of vertices (read in main)
int k;              // number of colors (read in main)
vector<int> ad[N];  // adjacency lists; vertices are numbered from 1
int c[N];           // c[v]: color of vertex v

// mk[v] becomes true once v has been used as a centroid; marked vertices are
// treated as walls that separate the remaining components.
bool mk[N];

int sz[N], totalSize;
// Fills sz[x] (subtree size) for every vertex of the component containing u,
// rooted at u. Vertices with mk[] set, and the vertex p, are treated as lying
// outside the component. Afterwards totalSize == sz[u], the component's size.
//
// Implemented without recursion. A recursive walk nests as deep as the
// component is tall, which on a path-shaped tree of ~2e5 vertices can exceed
// a typical default call-stack limit (e.g. 8 MiB on many Linux systems).
void initDFS(int u, int p = -1) {
    // (vertex, parent) pairs in BFS order: every parent precedes its children.
    vector<pair<int, int>> order;
    order.emplace_back(u, p);
    for (size_t i = 0; i < order.size(); ++i) {
        const auto [x, px] = order[i];  // copy: emplace_back below may reallocate
        sz[x] = 1;
        for (const auto& v : ad[x]) {
            if (mk[v] || v == px) continue;
            order.emplace_back(v, x);
        }
    }
    // A reverse sweep finalises every subtree before adding it to its parent.
    // Index 0 is u itself. Its parent p is outside the component, so it is
    // never updated.
    for (size_t i = order.size() - 1; i > 0; --i) {
        sz[order[i].second] += sz[order[i].first];
    }
    totalSize = sz[u];
}
// Returns a centroid of the component rooted at u, whose parent is p.
// It relies on sz[] and totalSize having been filled by initDFS from the same
// root. Starting at u, it repeatedly steps into the first unmarked child whose
// subtree holds more than half of totalSize. The vertex with no such child is
// returned.
int centroid(int u, int p = -1) {
    const int half = totalSize / 2;
    for (;;) {
        int heavy = -1;
        for (const int v : ad[u]) {
            if (!mk[v] && v != p && sz[v] > half) {
                heavy = v;
                break;
            }
        }
        if (heavy == -1) return u;
        p = u;
        u = heavy;
    }
}

int par[N];
vector<int> nodes;
// Appends to `nodes` every vertex of the subtree rooted at u (whose parent is
// p), in pre-order. It does not descend into vertices with mk[] set. For every
// appended vertex other than u, par[] is set to its parent within that
// subtree; par[u] is not touched.
//
// Uses an explicit stack instead of recursion. A recursive walk nests as deep
// as the subtree is tall, which on a path-shaped tree of ~2e5 vertices can
// overflow a typical default call stack. Children are pushed in reverse
// adjacency order, so vertices come out in the same order as a recursive
// pre-order walk.
void collectDFS(int u, int p = -1) {
    vector<pair<int, int>> stk;  // (vertex, parent) pairs still to visit
    stk.emplace_back(u, p);
    while (!stk.empty()) {
        const auto [x, px] = stk.back();  // copy before pop_back
        stk.pop_back();
        nodes.push_back(x);
        for (auto it = ad[x].rbegin(); it != ad[x].rend(); ++it) {
            const int v = *it;
            if (mk[v] || v == px) continue;
            par[v] = x;
            stk.emplace_back(v, x);
        }
    }
}

// Disjoint-set forest over vertices. id[x] < 0 means x is a root whose set has
// -id[x] members; otherwise id[x] is x's parent in the forest.
int id[N];
// Returns the representative of u's set and points every vertex on the path
// directly at that representative (path compression).
int root(int u) {
    int r = u;
    while (id[r] >= 0) r = id[r];
    while (u != r) {
        const int next = id[u];
        id[u] = r;
        u = next;
    }
    return r;
}
// Merges the sets containing u and v. The root of the larger set survives;
// on a tie, u's root survives.
void join(int u, int v) {
    int a = root(u);
    int b = root(v);
    if (a == b) return;
    if (id[b] < id[a]) swap(a, b);
    id[a] += id[b];
    id[b] = a;
}

// Per-color bookkeeping used by the centroid step in dfs().
vector<int> pos[N];  // pos[col]: every vertex whose color is col
int cntC[N];         // cntC[col]: vertices of color col in the current component
bool mkC[N];         // mkC[col]: col has already been queued at this level

int answer{MAX};     // best candidate found so far; printed by main
// One centroid-decomposition step on the component containing u. The
// component is bounded by marked vertices; p is the previously processed
// centroid, or -1 at the top level.
// At the centroid it grows a set of colors, starting from the centroid's own
// color. Whenever a color joins, every vertex on the paths from that color's
// vertices up to the centroid adds its color too. The final set size minus one
// is a candidate for `answer`. If a color in the set also occurs outside this
// component, the candidate is rejected. The standard centroid argument is that
// such a set was already considered at an enclosing centroid.
// It then recurses into every piece left after removing the centroid.
void dfs(int u, int p = -1) { 
    // Size the component, move u to its centroid, and mark it so deeper calls
    // treat it as a boundary.
    initDFS(u, p);
    mk[u = centroid(u, p)] = true;

    // Gather all vertices of the component into `nodes`. par[] records each
    // vertex's parent when the component is rooted at the centroid.
    nodes = {u};
    for (const auto& v : ad[u]) { 
        if (mk[v] || v == p) continue;
        par[v] = u;
        collectDFS(v, u);
    }

    // cntC[col] = number of vertices of color col inside this component.
    for (const auto& x : nodes) cntC[c[x]] += 1;

    // Number of colors pulled into the set, or MAX if the set is rejected.
    int totalC = 0;

    // BFS over colors seeded with the centroid's color. mkC[] prevents a color
    // from being queued twice.
    queue<int> q({c[u]});
    mkC[c[u]] = true;
    while (q.size()) { 
        auto color = q.front(); q.pop();
        totalC += 1;
        // Some vertex of this color lies outside the component: reject.
        if (cntC[color] != (int)pos[color].size()) { 
            totalC = MAX;
            break;
        }

        // From each vertex of this color, climb par[] toward the centroid
        // until reaching a vertex already in the centroid's DSU set. Each
        // vertex passed joins that set, and its color is queued if new.
        for (auto x : pos[color]) { 
            for (; root(x) != root(u); x = par[x]) {
                if (!mkC[c[x]]) { 
                    q.push(c[x]);
                    mkC[c[x]] = true;
                }
                join(x, u);
            }
        }
    }
    // A rejected level contributes MAX - 1, which never beats a feasible one.
    answer = min(answer, totalC - 1);

    // Reset this level's per-color and DSU state. Only entries for vertices of
    // this component (and their colors) were modified above.
    for (const auto& x : nodes) {
        cntC[c[x]] = 0;
        mkC[c[x]] = false;
        id[x] = -1;
    }

    // Recurse into each piece left after removing the centroid.
    for (const auto& v : ad[u]) { 
        if (mk[v] || v == p) continue;
        dfs(v, u);
    }
}

int32_t main() { 
    cin.tie(0)->sync_with_stdio(0);

    cin >> n >> k;
    for (int i = 1; i < n; ++i) { 
        int u, v; cin >> u >> v;
        ad[u].push_back(v);
        ad[v].push_back(u);
    }
    for (int i = 1; i <= n; ++i) cin >> c[i];

    for (int i = 1; i <= n; ++i) pos[c[i]].push_back(i);
    memset(id, -1, sizeof id);

    dfs(1);

    cout << answer << "\n";
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...