#include <bits/stdc++.h>
using namespace std;
// adj[v]: neighbours of vertex v (vertices are 1-indexed; row 0 is unused).
// parent[v][j]: binary-lifting table; dfs fills column 0, main fills the higher columns.
vector<vector<int>> adj, parent;
// depth[v]: number of edges between v and the vertex dfs was started from.
vector<int> depth;

/**
 * Roots the tree at `node`, setting depth[] and parent[][0] for every vertex reachable
 * from it without passing through `prev`.
 *
 * Implemented with an explicit stack instead of recursion so that very deep trees
 * (e.g. a path of 2e5 vertices) cannot overflow the call stack.
 *
 * @param node root of the traversal
 * @param prev value stored as the root's parent; this neighbour is not descended into
 * @param dis  depth assigned to the root
 */
void dfs(int node, int prev, int dis) {
    depth[node] = dis;
    parent[node][0] = prev;
    vector<int> todo{node};
    while (!todo.empty()) {
        const int v = todo.back();
        todo.pop_back();
        for (int child : adj[v]) {
            if (child == parent[v][0]) continue;  // do not walk back up the tree
            depth[child] = depth[v] + 1;
            parent[child][0] = v;
            todo.push_back(child);
        }
    }
}
/**
 * Climbs k levels up from x using the binary-lifting table: when bit b of k is set,
 * x jumps to parent[x][b]. Stops early (returning 0) once x becomes 0.
 * NOTE(review): only columns that exist in parent[x] may be indexed, so k must be
 * smaller than 2^(number of columns) -- main builds 20 columns.
 */
int ancestor(int x, int k) {
    for (int bit = 0; k != 0 && x != 0; ++bit, k >>= 1) {
        if (k & 1) x = parent[x][bit];
    }
    return x;
}
/**
 * Lowest common ancestor of a and b via binary lifting over levels 19..0.
 * Relies on depth[] and parent[][] having been filled for both vertices.
 */
int lca(int a, int b) {
    // Make `a` the deeper endpoint, then lift it to b's depth.
    if (depth[b] > depth[a]) swap(a, b);
    a = ancestor(a, depth[a] - depth[b]);
    if (a == b) return a;  // b was an ancestor of the deeper vertex
    // Lift both while their 2^lvl-th ancestors differ; they finish as children of the LCA.
    for (int lvl = 19; lvl >= 0; --lvl) {
        const int upA = parent[a][lvl];
        const int upB = parent[b][lvl];
        if (upA == upB) continue;
        a = upA;
        b = upB;
    }
    return parent[a][0];
}
// componentIds[v]: id of the component vertex v was merged into (assigned by the caller).
vector<int> componentIds;
// adjComponents[c]: ids of the components joined to component c by at least one tree edge.
vector<set<int>> adjComponents;

/**
 * Visits every tree edge reachable from `node` (not going back through `prev`) and,
 * for each edge whose endpoints lie in different components, records the two
 * components as neighbours of each other in adjComponents.
 *
 * Uses an explicit stack instead of recursion so deep trees cannot overflow the call
 * stack; the resulting sets do not depend on visiting order.
 *
 * @param node start vertex
 * @param prev neighbour of `node` to skip; 0 means "none" (vertex ids start at 1)
 */
void dfsComponents(int node, int prev) {
    vector<pair<int, int>> todo;  // (vertex, vertex it was reached from)
    todo.emplace_back(node, prev);
    while (!todo.empty()) {
        const int v = todo.back().first;
        const int from = todo.back().second;
        todo.pop_back();
        if (from && componentIds[from] != componentIds[v]) {
            adjComponents[componentIds[from]].insert(componentIds[v]);
            adjComponents[componentIds[v]].insert(componentIds[from]);
        }
        for (int child : adj[v]) {
            if (child != from) todo.emplace_back(child, v);
        }
    }
}
/**
 * Disjoint-set union over indices [0, N) with union by size and path compression.
 * Encoding of e: e[x] < 0 means x is a root whose set has -e[x] members;
 * otherwise e[x] is the next vertex on x's path towards its root.
 */
struct DSU {
    vector<int> e;
    DSU(int N) : e(N, -1) {}

    // Representative of x's set; every vertex on the way is re-pointed straight at it.
    int get(int x) {
        int root = x;
        while (e[root] >= 0) root = e[root];
        while (x != root) {
            const int next = e[x];
            e[x] = root;
            x = next;
        }
        return root;
    }

    bool same_set(int a, int b) { return get(a) == get(b); }

    int size(int x) { return -e[get(x)]; }

    // Merges the sets of x and y; the larger set's root survives (x's root on ties).
    // Returns false when they were already in the same set.
    bool unite(int x, int y) {
        int rx = get(x), ry = get(y);
        if (rx == ry) return false;
        if (-e[rx] < -e[ry]) swap(rx, ry);
        e[rx] += e[ry];
        e[ry] = rx;
        return true;
    }
};
/**
 * Input: n k, then n-1 tree edges "u v" (vertices 1..n), then the state label
 * (1..k) of each vertex 1..n.
 *
 * Merges vertices into components: starting from each not-yet-visited state, a
 * component absorbs every vertex of that state plus every tree path from those
 * vertices up to their common LCA; each state met on such a path (or at the LCA) is
 * pulled in too, repeating until no new vertex appears.
 *
 * Prints ceil(L / 2), where L is the number of components adjacent (via a tree edge)
 * to exactly one other component.
 *
 * `jump[v]` points towards the highest ancestor reachable from v over tree edges that
 * are already merged, so each tree edge is climbed only once overall. Without it,
 * vertices sharing a long path segment would each re-walk it (quadratic on, e.g., a
 * long chain ending in many leaves of the same state). Skipped vertices already had
 * their state added and their parent edge merged, so skipping them changes nothing.
 */
signed main() {
    ios::sync_with_stdio(0);
    cin.tie(0);
    int n, k;
    cin >> n >> k;
    adj = vector<vector<int>>(n + 1);
    parent = vector<vector<int>>(n + 1, vector<int>(20, 0));
    depth = vector<int>(n + 1, 0);
    for (int i = 0; i < n - 1; i++) {
        int u, v;
        cin >> u >> v;
        adj[u].push_back(v);
        adj[v].push_back(u);
    }
    dfs(1, 0, 0);
    // parent[v][lvl] = 2^lvl-th ancestor of v, left as 0 once it would pass the root.
    for (int lvl = 1; lvl < 20; lvl++) {
        for (int v = 1; v <= n; v++) {
            if (parent[v][lvl - 1]) parent[v][lvl] = parent[parent[v][lvl - 1]][lvl - 1];
        }
    }

    DSU dsu(n + 1);

    // jump[v] == v: the edge v -> parent[v][0] is not merged yet.
    vector<int> jump(n + 1);
    iota(jump.begin(), jump.end(), 0);
    auto topOf = [&](int v) {
        int top = v;
        while (jump[top] != top) top = jump[top];
        while (jump[v] != top) {  // path compression
            const int next = jump[v];
            jump[v] = top;
            v = next;
        }
        return top;
    };

    vector<int> states(n + 1);
    vector<bool> stateVisited(k + 1, false), nodeVisited(n + 1, false);
    vector<vector<int>> st(k + 1);  // st[s]: vertices labelled with state s
    for (int i = 1; i <= n; i++) {
        cin >> states[i];
        st[states[i]].push_back(i);
    }

    vector<int> pending;  // vertices of the current component, in discovery order
    // Marks `state` visited and enqueues its not-yet-seen vertices (no-op if already visited).
    auto addState = [&](int state) {
        if (stateVisited[state]) return;
        stateVisited[state] = true;
        for (int v : st[state]) {
            if (nodeVisited[v]) continue;
            nodeVisited[v] = true;
            pending.push_back(v);
        }
    };

    for (int i = 1; i <= k; i++) {
        if (stateVisited[i]) continue;
        pending.clear();
        addState(i);
        int anc = 0;
        size_t idx = 0;
        while (idx < pending.size()) {
            const size_t batchEnd = pending.size();
            // Raise the common ancestor to cover every vertex discovered so far.
            for (size_t j = idx; j < batchEnd; j++) {
                anc = anc ? lca(anc, pending[j]) : pending[j];
            }
            // Merge each new vertex's path up to anc, skipping already-merged stretches.
            for (; idx < batchEnd; idx++) {
                int v = topOf(pending[idx]);
                while (depth[v] > depth[anc]) {  // v is a strict descendant of anc
                    addState(states[v]);
                    const int up = parent[v][0];
                    dsu.unite(v, up);
                    jump[v] = up;
                    v = topOf(up);
                }
            }
            addState(states[anc]);
        }
    }

    componentIds = vector<int>(n + 1);
    adjComponents = vector<set<int>>(n + 1);
    set<int> compIds;
    for (int v = 1; v <= n; v++) {
        componentIds[v] = dsu.get(v);
        compIds.insert(componentIds[v]);
    }
    dfsComponents(1, 0);
    int leaves = 0;  // components with exactly one neighbouring component
    for (int id : compIds) {
        if (adjComponents[id].size() == 1) leaves++;
    }
    cout << (leaves + 1) / 2 << '\n';
    return 0;
}
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |