// This submission was migrated from the previous version of oj.uz, which used a different grading machine; it may receive a different result if resubmitted.
#include <bits/stdc++.h>
using namespace std;
// Reads a tree and a group label for every node, then:
//   1. marks every edge that lies on a path between two nodes of the same
//      group (i.e. the minimal subtree spanning each group),
//   2. contracts all marked edges with a union-find,
//   3. counts contracted components that touch exactly one remaining edge
//      (the leaves of the contracted tree), and prints ceil(leaves / 2).
// NOTE(review): ceil(leaves / 2) is the classic number of extra edges needed
// to make a tree 2-edge-connected; presumably that is what the task asks —
// confirm against the problem statement.
//
// Input (as read below):
//   n k
//   n - 1 lines "a b"  — tree edges, nodes numbered 1..n
//   n integers s_i     — group of node i, expected to be in 1..k
//
// The DFS and the union-find are iterative so that deep (path-shaped) trees
// cannot overflow the call stack.
int main()
{
    cin.tie(0)->sync_with_stdio(0);
    int n, k;
    cin >> n >> k;
    vector<vector<int>> g(n + 1), gg(k + 1);
    for (int i = 1; i < n; i++) {
        int a, b;
        cin >> a >> b;
        g[a].emplace_back(b);
        g[b].emplace_back(a);
    }
    for (int i = 1; i <= n; i++) {
        int s;
        cin >> s;
        gg[s].emplace_back(i);
    }

    // Binary-lifting levels: 2^lg >= n, so every depth difference (at most
    // n - 1) fits in bits 0..lg-1. Computed from n rather than fixed so a
    // very deep tree can never index past the table.
    int lg = 1;
    while ((1 << lg) < n) {
        lg++;
    }

    // Rooted at node 1. jump[0][v] is v's parent; the root's parent is the
    // sentinel 0, and jump[i][0] stays 0, so lifting past the root is harmless.
    vector<int> st(n + 1), dep(n + 1);
    vector<vector<int>> jump(lg + 1, vector<int>(n + 1));
    vector<int> ord; // DFS preorder: every parent precedes its children
    ord.reserve(n);
    {
        int T = 0;
        vector<int> stk{1};
        while (!stk.empty()) {
            int u = stk.back();
            stk.pop_back();
            // A node's whole subtree is popped before anything below it on the
            // stack, so st[] is a genuine DFS entry time (subtrees contiguous).
            st[u] = ++T;
            ord.push_back(u);
            for (int to : g[u]) {
                if (to == jump[0][u]) {
                    continue;
                }
                dep[to] = dep[u] + 1;
                jump[0][to] = u;
                stk.push_back(to);
            }
        }
    }
    for (int i = 1; i <= lg; i++) {
        for (int j = 1; j <= n; j++) {
            jump[i][j] = jump[i - 1][jump[i - 1][j]];
        }
    }

    // Ancestor of e that lies dx levels above it.
    auto anc = [&](int e, int dx) {
        for (int i = 0; dx > 0; i++, dx >>= 1) {
            if (dx & 1) {
                e = jump[i][e];
            }
        }
        return e;
    };
    // Lowest common ancestor of a and b via binary lifting.
    auto lca = [&](int a, int b) {
        if (dep[a] > dep[b]) {
            swap(a, b);
        }
        b = anc(b, dep[b] - dep[a]);
        if (a == b) {
            return a;
        }
        for (int i = lg; i >= 0; i--) {
            if (jump[i][a] != jump[i][b]) {
                a = jump[i][a];
                b = jump[i][b];
            }
        }
        return jump[0][a];
    };

    // Edge-difference counts. For each group, members are sorted by DFS entry
    // time and consecutive members are joined by a path; the union of those
    // paths covers the minimal subtree spanning the group. After the
    // bottom-up sum below, pre[v] is the number of such paths that use the
    // edge (v, parent(v)).
    vector<int> pre(n + 1);
    for (int i = 1; i <= k; i++) {
        auto& grp = gg[i];
        sort(grp.begin(), grp.end(), [&](int a, int b) {
            return st[a] < st[b];
        });
        for (size_t j = 0; j + 1 < grp.size(); j++) {
            int a = grp[j], b = grp[j + 1], c = lca(a, b);
            pre[a]++;
            pre[b]++;
            pre[c] -= 2;
        }
    }

    // Union-find over nodes; iterative find with path compression.
    vector<int> par(n + 1);
    iota(par.begin(), par.end(), 0);
    auto find_root = [&](int x) {
        int r = x;
        while (par[r] != r) {
            r = par[r];
        }
        while (par[x] != r) {
            int next = par[x];
            par[x] = r;
            x = next;
        }
        return r;
    };
    auto unite = [&](int a, int b) {
        par[find_root(a)] = find_root(b);
    };

    // Reverse preorder reaches every child before its parent, so pre[x] is
    // final when x is processed: push it up and contract (x, parent) if any
    // path uses that edge.
    for (int i = n - 1; i >= 0; i--) {
        int x = ord[i], p = jump[0][x];
        if (p == 0) {
            continue; // the root has no parent edge
        }
        pre[p] += pre[x];
        if (pre[x] > 0) {
            unite(p, x);
        }
    }

    // Degree of each contracted component. Two connected pieces of a tree
    // share at most one edge, so each uncontracted edge adds exactly one to
    // each side. Nodes 2..n all have a parent because the root is 1.
    vector<int> deg(n + 1);
    for (int v = 2; v <= n; v++) {
        int a = find_root(jump[0][v]), b = find_root(v);
        if (a != b) {
            deg[a]++;
            deg[b]++;
        }
    }
    int leaves = static_cast<int>(count(deg.begin(), deg.end(), 1));
    cout << (leaves + 1) / 2 << '\n';
}
/* Grader result table captured from the submission page (results were still loading when the page was saved):
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
*/