/*
Submission #1276825

# | Time | Username | Problem | Language | Result | Execution time | Memory
1276825 | | duckindog | Mergers (JOI19_mergers) | C++20 | 0 / 100 | 49 ms | 15948 KiB
*/
#include <bits/stdc++.h>

using namespace std;

// Capacity for 1-based vertex / state ids, with a little slack.
const int N = 500'000 + 10;
int n, k;              // number of cities (vertices) and number of states
vector<int> ad[N];     // tree adjacency lists
vector<int> state[N];  // state[s] = list of cities whose state is s

// Preorder entry time st[], last preorder time inside the subtree ed[],
// running counter num, and tree parent par[] of every vertex.
int st[N], ed[N], num;
int par[N];

// Fills st/ed/par for the subtree rooted at u, where p is u's parent
// (-1 for the tree root). par[u] itself is left untouched.
// Implemented with an explicit stack instead of recursion: with n up to 5e5 a
// path-shaped tree would otherwise need ~n nested calls and can overflow the
// call stack. Neighbours are visited in adjacency-list order, so the numbering
// is exactly the one the recursive version produces.
void dfs(int u, int p = -1) {
    struct Frame {
        int node;
        int parent;
        size_t next;  // index in ad[node] of the next neighbour to examine
    };
    vector<Frame> stk;
    stk.push_back({u, p, 0});
    st[u] = ++num;
    while (!stk.empty()) {
        Frame& top = stk.back();
        if (top.next == ad[top.node].size()) {
            ed[top.node] = num;  // whole subtree numbered
            stk.pop_back();
            continue;
        }
        const int from = top.node;
        const int v = ad[from][top.next++];
        if (v == top.parent) continue;
        par[v] = from;
        st[v] = ++num;
        stk.push_back({v, from, 0});  // may invalidate `top`; it is not used again
    }
}

// True iff u is an ancestor of v in the DFS tree (u counts as its own ancestor).
inline bool anc(int u, int v) { return st[u] <= st[v] && ed[v] <= ed[u]; }

// Disjoint-set forest: id[x] < 0 marks x as a representative, and -id[x] is
// then its set size. Must be filled with -1 before use.
int id[N];

// Returns the representative of u's set and points every vertex on the way
// straight at it (full path compression). Iterative because the merge rule in
// join() can build parent chains of length ~n, which a recursive find would
// have to descend all at once.
int root(int u) {
    int r = u;
    while (id[r] >= 0) r = id[r];
    while (id[u] >= 0) {
        const int next = id[u];
        id[u] = r;
        u = next;
    }
    return r;
}

// Unites the sets of u and v. Whichever representative is a tree ancestor of
// the other becomes the new representative. When, as in main(), a vertex is
// only ever joined with its tree parent, each set is a connected subtree and
// its representative is its topmost vertex.
void join(int u, int v) {
    u = root(u);
    v = root(v);
    if (u == v) return;
    if (anc(v, u)) swap(u, v);
    id[u] += id[v];
    id[v] = u;
}

// Reads the tree (n cities, n-1 roads) and the state of every city. For every
// state, it merges the state's cities together with every tree path between
// them into one DSU component. Merging is done by walking vertices up to
// their parents, so each component is a connected subtree. The components
// then form a contracted tree with L leaves, and the program prints
// ceil(L / 2), the standard answer for JOI 2019 "Mergers". This is 0 when
// everything collapses into a single component.
int32_t main() {
    cin.tie(0)->sync_with_stdio(0);

    cin >> n >> k;
    for (int i = 1; i < n; ++i) {
        int u, v;
        cin >> u >> v;
        ad[u].push_back(v);
        ad[v].push_back(u);
    }
    for (int i = 1; i <= n; ++i) {
        int s;
        cin >> s;
        state[s].push_back(i);
    }

    // Root the tree at city 1. par[1] = 1 makes join(1, par[1]) a no-op.
    dfs(par[1] = 1);

    memset(id, -1, sizeof id);

    for (int i = 1; i <= k; ++i) {
        if (state[i].empty()) continue;
        // Invariant: u is the representative (topmost vertex) of the component
        // holding the cities of state i handled so far. Starting from the
        // representative is essential. state[i][0] may already lie below the
        // top of a component merged for an earlier state. The second loop
        // would then lift v to that top, which is != u. From there it would
        // keep merging upward and spin forever at vertex 1.
        int u = root(state[i][0]);
        for (int v : state[i]) {
            // Lift u's component until it is an ancestor of v.
            while (!anc(u, v)) {
                join(u, par[u]);
                u = root(u);
            }
            // Walk v upward, absorbing the path, until it reaches u's component.
            while (v != u) {
                join(v, par[v]);
                v = root(v);
            }
        }
    }

    // Degree of every component in the contracted tree. Components are
    // connected subtrees, so two of them share at most one tree edge. Each
    // such edge is counted once from each endpoint's side.
    vector<int> deg(n + 1, 0);
    for (int i = 1; i <= n; ++i) {
        const int ri = root(i);
        for (const int v : ad[i]) {
            if (root(v) != ri) ++deg[ri];
        }
    }

    int cntLeaf = 0;
    for (int i = 1; i <= n; ++i) {
        if (root(i) == i && deg[i] == 1) ++cntLeaf;
    }
    // Leaves are paired up by mergers: ceil(L / 2). A lone component (L = 0)
    // needs no merger.
    cout << (cntLeaf + 1) / 2 << "\n";
}
/*
# | Verdict | Execution time | Memory | Grader output
Fetching results...

# | Verdict | Execution time | Memory | Grader output
Fetching results...

# | Verdict | Execution time | Memory | Grader output
Fetching results...

# | Verdict | Execution time | Memory | Grader output
Fetching results...

# | Verdict | Execution time | Memory | Grader output
Fetching results...
*/