#include <bits/stdc++.h>
using namespace std;
// Speeds up iostream I/O: detaches C++ streams from C stdio and unties cin
// from cout so reads do not force a flush of pending output.
void fastIO(){
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
}
// From here on every `int` token (globals, struct members, locals, parameters)
// is `long long`; `main` is declared `signed` so its return type stays int.
#define int long long
// Uncomment to turn every `cerr << ...` debug statement into a no-op.
// #define cerr if (false) cerr
// Sentinel "infinity": identity for minimum queries and marker for unset values.
const int INF = (int)1e18;
// Capacity of every per-vertex / per-color array (presumably the maximum
// vertex count plus a small margin -- confirm against the problem limits).
const int MAXV = 500'020;
// V: number of vertices. Color: number of color ids iterated over in main.
int V, Color;
// Adjacency lists of the input tree (vertices are 0-indexed).
vector<int> tree[MAXV];
// color[v]: 0-indexed color of vertex v.
int color[MAXV];
// Per color c:
//   lowT[c] / highT[c]: first / last Euler-tour slot (see dfs1) at which a
//                       vertex of color c is recorded;
//   lca[c]:             depth of the lowest common ancestor of all vertices of
//                       color c (a depth, not a vertex id).
int lowT[MAXV], highT[MAXV], lca[MAXV];
// depth[v]: number of edges from the root (vertex 0) to v.
int depth[MAXV];
// low[v] (dfs2): minimum lca depth over the colors still present in v's
// subtree; set to INF once v closes its own component.
int low[MAXV];
// toTree2[v]: id of the component (node of tree2) that contains vertex v.
int toTree2[MAXV];
// Adjacency lists of the contracted tree whose nodes are the components.
vector<int> tree2[MAXV];
// Iterative (bottom-up) segment tree for range-minimum queries with point
// assignment. The leaf count is padded to a power of two and every leaf
// starts at INF.
struct Segtree {
    // Number of leaves: the smallest power of two >= the requested size.
    int numLeaves;
    // Heap layout: node k has children 2k and 2k+1; leaf i lives at
    // numLeaves + i. Index 0 is unused.
    vector<int> tree;

    // Builds a tree with room for at least n leaves, all set to INF.
    Segtree (int n) : numLeaves(1) {
        while (numLeaves < n) numLeaves <<= 1;
        tree.assign(2 * numLeaves, INF);
    }

    // Minimum over leaf positions l..r (both inclusive); INF when l > r.
    int query (int l, int r) {
        int best = INF;
        int lo = l + numLeaves;
        int hi = r + numLeaves + 1;  // exclusive upper bound
        for (; lo < hi; lo >>= 1, hi >>= 1) {
            if (lo & 1) best = min(best, tree[lo++]);
            if (hi & 1) best = min(best, tree[--hi]);
        }
        return best;
    }

    // Assigns val to leaf idx and recomputes the minimum of every ancestor.
    void update (int idx, int val) {
        int pos = idx + numLeaves;
        tree[pos] = val;
        for (pos >>= 1; pos >= 1; pos >>= 1) {
            tree[pos] = min(tree[2 * pos], tree[2 * pos + 1]);
        }
    }
};
// Range-min structure over Euler-tour slots: slot s holds the depth of the
// vertex recorded at s. A single tour from the root writes 3*V - 1 slots
// (two per vertex plus one per tree edge), which fits in 4*MAXV.
Segtree minQ(4*MAXV);
// Next free Euler-tour slot.
int t = 0;

// Records `node` at the current Euler slot: widens its color's
// [lowT, highT] window, stores depth[node] in minQ, and advances t.
static void recordEulerSlot (int node) {
    const int c = color[node];
    if (t < lowT[c]) lowT[c] = t;
    if (t > highT[c]) highT[c] = t;
    minQ.update(t, depth[node]);
    ++t;
}

// Euler tour rooted at `node`: records node on entry, again after returning
// from each child, and once more on exit; sets depth[child] on the way down.
// Recursion depth equals the tree height, so very deep (path-like) trees may
// need a larger-than-default call stack.
void dfs1 (int node, int parent) {
    recordEulerSlot(node);
    for (int child : tree[node]) {
        if (child == parent) continue;
        depth[child] = depth[node] + 1;
        dfs1(child, node);
        recordEulerSlot(node);
    }
    recordEulerSlot(node);
}
// Vertices entered by dfs2 that have not been assigned a component yet.
stack<int> pathNodes;
// Number of components created so far; also the id given to the next one.
int countTop = 0;

// Partitions the tree into components. After its children are processed,
// `node` starts a new component exactly when low[node] == depth[node], i.e.
// every color still present in node's remaining subtree has its LCA inside
// that subtree. All still-unassigned vertices pushed since `node` (node
// included) then receive id countTop. For the root this condition always
// holds, so the last component is closed there.
void dfs2 (int node, int parent) {
    pathNodes.push(node);
    low[node] = lca[color[node]];
    for (int child : tree[node]) {
        if (child == parent) continue;
        dfs2(child, node);
        // A child that closed its own component left low[child] == INF.
        if (low[child] < low[node]) low[node] = low[child];
    }
    if (low[node] != depth[node]) return;

    // This subtree no longer constrains its ancestors.
    low[node] = INF;
    int popped;
    do {
        popped = pathNodes.top();
        pathNodes.pop();
        toTree2[popped] = countTop;
    } while (popped != node);
    ++countTop;
}
// Builds tree2: every input-tree edge whose endpoints lie in different
// components is added, in both directions, between those two components.
void dfs3 (int node, int parent) {
    const int here = toTree2[node];
    for (int next : tree[node]) {
        if (next == parent) continue;
        const int there = toTree2[next];
        if (here != there) {
            tree2[here].push_back(there);
            tree2[there].push_back(here);
        }
        dfs3(next, node);
    }
}
// Reads the tree and the vertex colors, groups vertices into components
// (dfs1 + per-color LCA + dfs2), contracts them into tree2 (dfs3), and prints
// ceil(leaves / 2) of tree2, or 0 when there is only one component.
signed main() {
    fastIO();
    cin >> V >> Color;
    // V - 1 undirected edges, 1-indexed in the input.
    for (int i = 0; i < V-1; i++) {
        int u, v;
        cin >> u >> v;
        u--, v--;
        tree[u].push_back(v);
        tree[v].push_back(u);
    }
    // Colors are 1-indexed in the input; store them 0-indexed.
    for (int i = 0; i < V; i++) {
        cin >> color[i];
        color[i]--;
    }
    // lowT/highT are indexed by color id, not by vertex: reset a slot for
    // every possible color so none starts with a bogus lowT of 0 (which would
    // make its LCA collapse to the root). Assumes max(V, Color) <= MAXV.
    const int colorSlots = max(V, Color);
    fill(lowT, lowT + colorSlots, INF);
    fill(highT, highT + colorSlots, 0);
    dfs1(0, -1);
    // lca[i] = depth of the LCA of all vertices of color i; an unused color
    // keeps lowT > highT and therefore gets INF.
    for (int i = 0; i < Color; i++) {
        lca[i] = minQ.query(lowT[i], highT[i]);
    }
    fill(low, low+V, INF);
    dfs2(0, -1);
    dfs3(0, -1);
    if (countTop == 1) {
        cout << "0\n";
        return 0;
    }
    // Count leaves of the contracted tree; the answer is ceil(leaves / 2)
    // (presumably the classic bound for making a tree 2-edge-connected).
    int numLeaves = 0;
    for (int i = 0; i < countTop; i++) {
        if (tree2[i].size() == 1) {
            numLeaves++;
        }
    }
    cout << (numLeaves + 1) / 2 << "\n";
}
/*
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
*/