Submission #1138313

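| # | Time | Username | Problem | Language | Result | Execution time | Memory |
|---|------|----------|---------|----------|--------|----------------|--------|
| 1138313 | | OI_Account | Capital City (JOI20_capital_city) | C++20 | 100 / 100 | 2690 ms | 454040 KiB |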
#pragma GCC optimize("O2,unroll-loops,Ofast")
#include <bits/stdc++.h>
using namespace std;

// Problem-size limits.
const int N = 200'000;  // max vertices; pack[] is also indexed by color, so colors are presumably <= N too
const int S = N * 19;   // max graph nodes: k color nodes + n * (M + 1) sparse-table nodes (see calcIdSp)
const int M = 17;       // highest binary-lifting level; levels 0..M are stored

// n: vertex count, k: color count, c[v]: color of vertex v,
// ver: total number of graph nodes (colors are nodes 1..k, sparse nodes are k+1..ver).
int n, k, c[N + 10], ver;
// h[v]: depth of v (the root has depth 1, h[0] stays 0); dp[v][j]: 2^j-th ancestor of v (0 if none);
// idSp[v][j]: graph node standing for the 2^j vertices starting at v and going upward.
int h[N + 10], dp[N + 10][M + 1], idSp[N + 10][M + 1];
// cntCmp: number of SCCs found; id[x]: 0 = unvisited, -1 = finished the first SCC pass,
// otherwise the SCC number of node x; mn/mx[x]: min/max SCC number over the colors
// covered by sparse node x.
int cntCmp, id[S + 10], mn[S + 10], mx[S + 10];
// adj: tree adjacency; vec: node order produced by the first SCC pass; pack[col]: vertices of color col.
vector<int> adj[N + 10], vec, pack[N + 10];
// out/in: forward and reversed edge lists of the graph over color and sparse nodes;
// good: sorted distinct SCC numbers of the color nodes.
vector<int> out[S + 10], in[S + 10], good;
// sum[i]: number of colors whose SCC number is good[i].
int sum[N + 10];
// notOut[i]: true if SCC good[i] has an edge (directly or via a sparse node) into a different SCC.
bool notOut[N + 10];

// Reads n and k, then the n-1 tree edges (stored in both directions in adj), then the
// color of every vertex. Each vertex is also bucketed into pack[] by its color.
void readInput() {
    cin >> n >> k;
    for (int e = 0; e + 1 < n; ++e) {
        int a, b;
        cin >> a >> b;
        adj[a].push_back(b);
        adj[b].push_back(a);
    }
    for (int v = 1; v <= n; ++v) {
        cin >> c[v];
        pack[c[v]].push_back(v);
    }
}

// Gives every sparse-table cell (vertex v, level lvl) its own graph-node id. Ids are
// numbered consecutively right after the k color nodes. Each cell's out-list is
// pre-reserved for its (at most two) edges. Afterwards ver holds the total node count.
void calcIdSp() {
    int nodeId = k;
    for (int v = 1; v <= n; ++v) {
        for (int lvl = 0; lvl <= M; ++lvl) {
            ++nodeId;
            idSp[v][lvl] = nodeId;
            out[nodeId].reserve(2);
        }
    }
    ver = nodeId;
}

// Appends the directed edge u -> v to the forward adjacency list only.
// The reversed lists in[] are built later, in SCC().
void addEdge(int u, int v) {
    out[u].emplace_back(v);
}

void dfs(int u = 1, int par = 0) {
    h[u] = h[par] + 1;
    dp[u][0] = par;
    addEdge(idSp[u][0], c[u]);
    for (int j = 1; j <= M && dp[u][j - 1]; j++) {
        dp[u][j] = dp[dp[u][j - 1]][j - 1];
        addEdge(idSp[u][j], idSp[u][j - 1]);
        addEdge(idSp[u][j], idSp[dp[u][j - 1]][j - 1]);
    }
    for (auto v: adj[u])
        if (v != par)
            dfs(v, u);
}

// Lowest common ancestor of u and v, using the binary-lifting table dp[][] and the depths h[].
int LCA(int u, int v) {
    if (h[u] < h[v])
        swap(u, v);
    // First lift the deeper vertex u up to v's depth.
    for (int lvl = M; lvl >= 0; --lvl) {
        if (h[u] - h[v] >= (1 << lvl))
            u = dp[u][lvl];
    }
    if (u == v)
        return u;
    // Then lift both vertices while their ancestors still differ.
    for (int lvl = M; lvl >= 0; --lvl) {
        const int upU = dp[u][lvl];
        const int upV = dp[v][lvl];
        if (upU == upV)
            continue;
        u = upU;
        v = upV;
    }
    return dp[u][0];
}

// Links graph node x to every tree vertex on the path from u up to, but not including,
// v. It does this with O(log) sparse-table nodes idSp[u][lvl], greedily taking the
// largest power-of-two block that still fits. It then links x straight to v's color.
// NOTE(review): assumes v is an ancestor of u (the caller passes an LCA).
void addEdges(int x, int u, int v) {
    for (int lvl = M; lvl >= 0; --lvl) {
        const bool fits = h[u] - h[v] >= (1 << lvl);
        if (!fits)
            continue;
        addEdge(x, idSp[u][lvl]);
        u = dp[u][lvl];
    }
    addEdge(x, c[v]);
}

// Builds the color-level edges. For every color i it finds the LCA of all vertices of
// color i. From each such vertex u, it connects color node i to everything on the
// tree path from u up to that LCA. Finally it trims the spare capacity of every
// out-list.
void calcGraph() {
    for (int i = 1; i <= k; i++) {
        // A color with no vertices has no path to cover. Skipping it also avoids
        // reading pack[i][0] out of bounds.
        if (pack[i].empty())
            continue;
        int lca = pack[i][0];
        for (size_t j = 1; j < pack[i].size(); j++)
            lca = LCA(lca, pack[i][j]);
        for (int u : pack[i])
            addEdges(c[u], u, lca);
    }
    for (int i = 1; i <= ver; i++)
        out[i].shrink_to_fit();
}

inline void dfsOut(int u) {
    id[u] = -1;
    for (auto v: out[u])
        if (!id[v])
            dfsOut(v);
    vec.push_back(u);
}

inline void dfsIn(int u) {
    id[u] = cntCmp;
    for (auto v: in[u])
        if (id[v] == -1)
            dfsIn(v);
}

void SCC() {    
    vec.reserve(ver);
    for (int i = 1; i <= ver; i++)
        if (!id[i])
            dfsOut(i);
    reverse(vec.begin(), vec.end());
    for (int i = k + 1; i <= ver; i++) {
        for (auto j: out[i])
            in[j].push_back(i);
        out[i].clear();
        out[i].shrink_to_fit();
    }
    for (int i = 1; i <= k; i++)
        for (auto j: out[i])
            in[j].push_back(i);
    for (auto u: vec)
        if (id[u] == -1) {
            cntCmp++;
            dfsIn(u);
        }
}

void dfsMinMax(int u = 1, int par = 0) {
    mn[idSp[u][0]] = mx[idSp[u][0]] = id[c[u]];
    for (int j = 1; j <= M && dp[u][j - 1]; j++) {
        mn[idSp[u][j]] = min(mn[idSp[u][j - 1]], mn[idSp[dp[u][j - 1]][j - 1]]);
        mx[idSp[u][j]] = max(mx[idSp[u][j - 1]], mx[idSp[dp[u][j - 1]][j - 1]]);
    }
    for (auto v: adj[u])
        if (v != par)
            dfsMinMax(v, u);
}

int getIdx(int x) {
    return lower_bound(good.begin(), good.end(), x) - good.begin();
}

void init() {
    for (int u = 1; u <= k; u++)
        good.push_back(id[u]);
    sort(good.begin(), good.end());
    good.resize(unique(good.begin(), good.end()) - good.begin());
    for (int u = 1; u <= k; u++) {
        int idx = getIdx(id[u]);
        sum[idx]++;
        for (auto v: out[u])
            if (v <= k)
                notOut[idx] |= (id[v] != id[u]);
            else
                notOut[idx] |= (mn[v] != id[u] || mx[v] != id[u]);
    }
}

// Prints, over all components with no outgoing edge, the smallest (color count - 1).
// The result is capped at k.
void calcAns() {
    int best = k;
    for (size_t i = 0; i < good.size(); ++i) {
        if (notOut[i] || sum[i] == 0)
            continue;
        best = min(best, sum[i] - 1);
    }
    cout << best << flush;
}

int main() {
    // Fast, unsynchronized stream I/O.
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);

    readInput();   // tree + colors
    calcIdSp();    // allocate sparse-table node ids
    dfs();         // depths, ancestors, sparse-table edges
    calcGraph();   // color -> path-cover edges
    SCC();         // strongly connected components
    dfsMinMax();   // per-sparse-node min/max component numbers
    init();        // per-component color counts and outgoing-edge flags
    calcAns();     // print the answer
    return 0;
}
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...