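// Submission #595407
//
// #      | Username | Problem                 | Language | Result    | Execution time | Memory
// 595407 | elkernos | Mergers (JOI19_mergers) | C++17    | 100 / 100 | 850 ms         | 139500 KiB
// (The "Time" column of the original table had no value in this capture.)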
// Solution for "Mergers" (JOI19_mergers).
//
// Input : n k, then n-1 roads "a b", then the state (1..k) of each city 1..n.
// Output: ceil(L / 2), where L is the number of leaves of the tree obtained by
//         contracting every road that lies on a path between two cities of the
//         same state.
//
// Implementation notes:
//  * A road (v, parent(v)) is contracted iff at least one path between two
//    preorder-consecutive cities of some state crosses it. Path usage is
//    counted with the tree difference trick (+1 at both endpoints, -2 at their
//    LCA, then subtree sums).
//  * The DFS and the DSU find are iterative: the tree height can be ~n, and a
//    recursion that deep can overflow the call stack.
//  * The number of binary-lifting levels is derived from n instead of being
//    hard-coded, so any n is supported.
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <numeric>
#include <utility>
#include <vector>

int main() {
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int n = 0;
    int k = 0;
    std::cin >> n >> k;

    std::vector<std::vector<int>> adj(n + 1);
    for (int i = 1; i < n; ++i) {
        int a = 0;
        int b = 0;
        std::cin >> a >> b;
        adj[a].push_back(b);
        adj[b].push_back(a);
    }

    // members[s] = cities whose state is s.
    std::vector<std::vector<int>> members(k + 1);
    for (int v = 1; v <= n; ++v) {
        int s = 0;
        std::cin >> s;
        members[s].push_back(v);
    }

    // Smallest `levels` with 2^levels > n, so any depth difference (< n) fits
    // in `levels` bits.
    int levels = 1;
    while ((1 << levels) <= n) {
        ++levels;
    }

    // up[j][v] = 2^j-th ancestor of v; 0 is a sentinel "above the root".
    std::vector<std::vector<int>> up(levels, std::vector<int>(n + 1, 0));
    std::vector<int> depth(n + 1, 0);
    std::vector<int> tin(n + 1, 0);  // preorder index of each city
    std::vector<int> order;          // cities in DFS preorder (parents first)
    order.reserve(static_cast<std::size_t>(n));

    // Iterative DFS rooted at city 1. Pushing all children when a city is
    // popped still yields a valid preorder: a city's whole subtree is emitted
    // before anything that was below it on the stack.
    std::vector<int> pending{1};
    while (!pending.empty()) {
        const int v = pending.back();
        pending.pop_back();
        tin[v] = static_cast<int>(order.size());
        order.push_back(v);
        for (const int to : adj[v]) {
            if (to == up[0][v]) {
                continue;  // skip the edge back to the parent
            }
            up[0][to] = v;
            depth[to] = depth[v] + 1;
            pending.push_back(to);
        }
    }

    for (int j = 1; j < levels; ++j) {
        for (int v = 1; v <= n; ++v) {
            up[j][v] = up[j - 1][up[j - 1][v]];
        }
    }

    // Lowest common ancestor via binary lifting.
    auto lca = [&](int a, int b) {
        if (depth[a] > depth[b]) {
            std::swap(a, b);
        }
        for (int diff = depth[b] - depth[a], j = 0; diff > 0; diff >>= 1, ++j) {
            if (diff & 1) {
                b = up[j][b];
            }
        }
        if (a == b) {
            return a;
        }
        for (int j = levels - 1; j >= 0; --j) {
            if (up[j][a] != up[j][b]) {
                a = up[j][a];
                b = up[j][b];
            }
        }
        return up[0][a];
    };

    // Difference marks: after the subtree-sum pass below, cross[v] is the
    // number of same-state consecutive-city paths that use road (v, parent(v)).
    std::vector<int> cross(n + 1, 0);
    for (int s = 1; s <= k; ++s) {
        std::vector<int>& cities = members[s];
        std::sort(cities.begin(), cities.end(),
                  [&](int x, int y) { return tin[x] < tin[y]; });
        for (std::size_t i = 1; i < cities.size(); ++i) {
            const int a = cities[i - 1];
            const int b = cities[i];
            ++cross[a];
            ++cross[b];
            cross[lca(a, b)] -= 2;
        }
    }

    // Disjoint-set union over cities; iterative find with path halving.
    std::vector<int> dsu(n + 1);
    std::iota(dsu.begin(), dsu.end(), 0);
    auto find = [&](int v) {
        while (dsu[v] != v) {
            dsu[v] = dsu[dsu[v]];
            v = dsu[v];
        }
        return v;
    };

    // Reverse preorder visits children before parents, so cross[v] is final
    // (its whole subtree has been folded in) when v is reached. The root
    // (order[0]) has no parent road and is skipped.
    for (std::size_t i = order.size(); i-- > 1;) {
        const int v = order[i];
        const int p = up[0][v];
        cross[p] += cross[v];
        if (cross[v] > 0) {
            dsu[find(v)] = find(p);  // road is used by some state: contract it
        }
    }

    // Degree of each contracted component in the contracted tree.
    std::vector<int> degree(n + 1, 0);
    for (int v = 2; v <= n; ++v) {  // every city except the root 1
        const int a = find(up[0][v]);
        const int b = find(v);
        if (a != b) {
            ++degree[a];
            ++degree[b];
        }
    }

    const auto leaves = std::count(degree.begin(), degree.end(), 1);
    std::cout << (leaves + 1) / 2 << '\n';
    return 0;
}
// Per-subtask results (columns: # | Verdict | Execution time | Memory | Grader output)
// were not captured: all 5 result tables still showed "Fetching results..." in this
// capture of the page.