#include <algorithm>
#include <cassert>
#include <iostream>
#include <numeric>
#include <vector>
using namespace std;
// Capacity of every per-vertex / per-group array (N and K must not exceed it).
const int MAXN = 500000;
int N, K;
// dist[u]: depth of u in the tree rooted at vertex 0 (filled in by dfs).
int dist[MAXN];
// adj[u]: neighbours of u; dfs(0) erases each parent, so afterwards only children remain.
vector<int> adj[MAXN];
// group_of[u]: 0-based group index of vertex u.
int group_of[MAXN];
// members[g]: all vertices whose group is g.
vector<int> members[MAXN];

// Offset of the first leaf of the LCA segment tree. dfs writes an Euler tour of
// 2*N-1 entries into the leaves, so the leaf level must hold 2*MAXN-1 items;
// the previous value (1<<19) overflowed the array once N exceeded 2^18.
const int lcatree_shift = 1<<20;
static_assert(lcatree_shift >= 2*MAXN-1, "Euler tour does not fit in the LCA segment tree");

// Element of the min-depth segment tree used for LCA: a vertex index, or -1 for
// "empty". A shallower vertex (smaller dist) compares smaller; -1 compares
// greater than every vertex, so it is the identity for min().
struct lcatreeitem {
    int i;
    lcatreeitem& operator=(int other) {
        i = other;
        return *this;
    }
    operator int() const {
        return i;
    }
    bool operator<(const lcatreeitem& other) const {
        if (i == -1) return false;
        if (other.i == -1) return true;
        // Both sides are real vertices here; the check must come after the -1 cases.
        assert(i >= 0 && other.i >= 0 && i < N && other.i < N);
        return dist[i] < dist[other.i];
    }
} lcatree[2*lcatree_shift];  // [1, shift): internal nodes, [shift, 2*shift): leaves
// first[u]: index of u's first occurrence in the Euler tour.
int first[MAXN];
// Number of Euler-tour entries written so far by dfs.
int lcatree_i = 0;
// mergeuntil[u]: LCA of u's group once main() has filled it (initially u itself).
int mergeuntil[MAXN];
// penible[u]: u is the top of a union-find component other than the root's.
bool penible[MAXN];
// Union-find parent array.
int uf[MAXN];
// Union-find find: returns the representative of u's set and re-points every
// vertex on the walk from u directly at that representative.
int root(int u) {
    int top = u;
    while (uf[top] != top) {
        top = uf[top];
    }
    // Second pass: path compression, done iteratively so long chains do not
    // consume one stack frame per link.
    while (uf[u] != top) {
        const int next = uf[u];
        uf[u] = top;
        u = next;
    }
    return top;
}
// Unites the sets of u and v. The representative of the endpoint with the
// smaller dist (the first argument on ties) becomes the parent of the other
// endpoint's representative.
void merge(int u, int v) {
    int upper = u;
    int lower = v;
    if (dist[upper] > dist[lower]) {
        swap(upper, lower);
    }
    uf[root(lower)] = root(upper);
}
// Depth-first walk from u.
// - Appends u to the Euler tour on entry and again after each child returns.
//   The tour is stored in the lcatree leaves starting at lcatree_shift.
// - Records u's first tour position in first[u].
// - Sets each child's depth in dist.
// - Erases u from each child's adjacency list, so adj[] keeps only children.
void dfs(int u) {
    first[u] = lcatree_i;
    lcatree[lcatree_shift + lcatree_i] = u;
    ++lcatree_i;
    for (int child : adj[u]) {
        vector<int>& back = adj[child];
        back.erase(find(back.begin(), back.end(), u));  // drop edge to parent
        dist[child] = dist[u] + 1;
        dfs(child);
        // Re-emit u after returning from the child (Euler tour).
        lcatree[lcatree_shift + lcatree_i] = u;
        ++lcatree_i;
    }
}
// Post-order pass over the rooted tree.
// Returns the shallowest mergeuntil[] value found anywhere in u's subtree.
// If a child's subtree reports a vertex other than the child itself, some group
// reaches above that child; u and the child are then unioned.
int dfs2(int u) {
    int best = mergeuntil[u];
    for (int child : adj[u]) {
        const int top = dfs2(child);
        if (top != child) {
            merge(child, u);
        }
        if (dist[top] < dist[best]) {
            best = top;
        }
    }
    return best;
}
int lca(int u, int v) {
if (first[u] > first[v]) swap(u, v);
int l = lcatree_shift+first[u], r = lcatree_shift+first[v]+1;
lcatreeitem ret;
ret = -1;
for (; l < r; l >>= 1, r >>= 1) {
if (l&1) {
ret = min(ret, lcatree[l++]);
}
if (r&1) {
ret = min(ret, lcatree[--r]);
}
}
//cerr << "lca(" << u << ", " << v << ") = " << (int)ret << '\n';
return ret;
}
// Counts the penible vertices in u's subtree that have no penible vertex below them.
// A subtree with no penible vertex at all yields 0.
int dfs3(int u) {
    int below = 0;
    for (int child : adj[u]) {
        below += dfs3(child);
    }
    if (below == 0 && penible[u]) {
        return 1;
    }
    return below;
}
signed main() {
ios::sync_with_stdio(false), cin.tie(0), cout.tie(0);
cin >> N >> K;
//dist.resize(N);
//adj.resize(N);
//first.resize(N);
//group_of.resize(N);
//members.resize(K);
//lcatree.resize(1<<20);
//uf.resize(N);
//mergeuntil.resize(N);
//penible.resize(N);
iota(mergeuntil, mergeuntil+N, 0);
iota(uf, uf+N, 0);
int u, v;
for (int i = 0; i < N-1; i++) {
cin >> u >> v, u--, v--;
adj[u].push_back(v);
adj[v].push_back(u);
}
for (int i = 0; i < N; i++) {
cin >> group_of[i], group_of[i]--;
members[group_of[i]].push_back(i);
}
dfs(0);
for (int i = lcatree_shift-1; i > 0; i--) {
lcatree[i] = min(lcatree[2*i], lcatree[2*i+1]);
}
assert(N < 5e5);
for (int g = 0; g < K; g++) {
int a = members[g][0];
for (auto m : members[g]) {
a = lca(a, m);
}
for (auto m : members[g]) {
mergeuntil[m] = a;
}
}
dfs2(0);
bool plus1 = false;
int lcapenible = -1;
for (int i = 1; i < N; i++) {
if (root(i) == i) {
if (lcapenible == -1) lcapenible = i;
else lcapenible = lca(lcapenible, i);
penible[i] = true;
}
}
for (int i = 1; i < N; i++) {
if (penible[i]) {
if (lca(lcapenible, i) == i) {
plus1 = true;
break;
}
}
}
int ans = dfs3(0)+plus1;
cout << ((ans+1)>>1) << '\n';
}
Compilation message
mergers.cpp: In member function 'bool lcatreeitem::operator<(const lcatreeitem&) const':
mergers.cpp:26:9: error: 'assert' was not declared in this scope
26 | assert(i >= 0 && other.i >= 0 && i < N && other.i < N);
| ^~~~~~
mergers.cpp:5:1: note: 'assert' is defined in header '<cassert>'; did you forget to '#include <cassert>'?
4 | #include <numeric>
+++ |+#include <cassert>
5 | using namespace std;
mergers.cpp:27:9: error: 'else' without a previous 'if'
27 | else return dist[i] < dist[other.i];
| ^~~~
mergers.cpp: In function 'int main()':
mergers.cpp:125:5: error: 'assert' was not declared in this scope
125 | assert(N < 5e5);
| ^~~~~~
mergers.cpp:125:5: note: 'assert' is defined in header '<cassert>'; did you forget to '#include <cassert>'?