#include <bits/stdc++.h>
#pragma GCC optimize("O3,unroll-loops")
#pragma GCC target("avx2")
using namespace std;
// From here on every `int` is a 64-bit `long long` (which is why main is declared int32_t).
#define int long long
// Short-hands. pii, ff, ss and sp are not referenced in the code shown.
#define pii pair<int,int>
#define ff first
#define ss second
#define sp << " " <<
#define all(cont) cont.begin(),cont.end()
#define vi vector<int>
// inf and MOD are not referenced in the code shown.
// N sizes all per-vertex arrays; vertices are used 1-based, so this presumably assumes n <= 5e5.
const int inf = 1e18,N = 5e5+1,MOD = 998244353;
// Adjacency lists of the input tree.
vi edges[N];
// Next Euler-tour entry time handed out by dfs (first vertex gets 1).
int timer = 1;
// tin[u]: entry time of u; tout[u]: largest entry time inside u's subtree (both set by dfs).
vi tin(N),tout(N);
// up[u][k]: 2^k-th ancestor of u (binary lifting, 20 levels).
// With 64-bit entries this table alone occupies about 80 MB.
int up[N][20];
// Per-vertex difference marks written by solve; dfs2 turns them into subtree sums.
vi dp(N,0);
// Depth-first traversal from `node` (whose parent is `p`; the root is called as
// dfs(1,1) so it is its own parent). For every visited vertex u it sets:
//   tin[u]   - preorder entry time, taken from the global `timer`;
//   tout[u]  - largest entry time inside u's subtree, so subtree(u) is exactly
//              the tin interval [tin[u], tout[u]];
//   up[u][k] - the 2^k-th ancestor of u (binary-lifting table).
// Children are visited in adjacency-list order, exactly like the recursive form.
// An explicit stack is used because a path-shaped tree with ~N vertices would
// otherwise recurse ~N levels deep and can overflow the call stack.
void dfs(int node,int p) {
    struct Frame {
        int u;             // vertex being expanded
        int par;           // its parent (skipped when scanning neighbours)
        std::size_t next;  // index of the next neighbour to inspect
    };
    std::vector<Frame> st;

    // Work done when a vertex is first reached (the recursive function's prologue).
    auto enter = [&](int u, int par) {
        tin[u] = timer++;
        up[u][0] = par;
        // Ancestors of u were entered earlier, so their rows are already complete.
        for (int i = 1;i<20;i++) up[u][i] = up[up[u][i-1]][i-1];
        st.push_back({u, par, 0});
    };

    enter(node, p);
    while (!st.empty()) {
        Frame& f = st.back();
        if (f.next < edges[f.u].size()) {
            int it = edges[f.u][f.next++];
            if (it == f.par) continue;
            // enter() may reallocate `st`; `f` is not used after this call.
            enter(it, f.u);
        } else {
            // All children done: the last issued time is the end of u's subtree.
            tout[f.u] = timer-1;
            st.pop_back();
        }
    }
}
// Returns true when a is an ancestor of b, counting a vertex as its own ancestor:
// b's Euler-tour interval must be nested inside a's interval [tin[a], tout[a]].
bool anc(int a,int b) {
    if (tin[b] < tin[a]) return false;  // b entered before a: cannot be below it
    return tout[b] <= tout[a];
}
int lca(int a,int b) {
if (anc(a,b)) return a;
if (anc(b,a)) return b;
for (int i = 19;i>=0;i--) if (!anc(up[a][i],b)) a = up[a][i];
return up[a][0];
}
vi act;
void dfs2(int node,int p) {
for (auto it : edges[node]) {
if (it == p) continue;
dfs2(it,node);
dp[node]+=dp[it];
}
}
vi v;
void dfs3(int node,int p) {
if (!dp[node] && node > 1) v.push_back(node);
for (auto it : edges[node]) {
if (it == p) continue;
dfs3(it,node);
}
}
void solve() {
int n,q;
cin >> n >> q;
for (int i = 1;i<n;i++) {
int a,b;
cin >> a >> b;
edges[a].push_back(b);
edges[b].push_back(a);
}
dfs(1,1);
vi c(n+1);
vi buck[q+1];
for (int i = 1;i<=n;i++) {
cin >> c[i];
buck[c[i]].push_back(i);
}
for (int i = 1;i<=q;i++) {
if (buck[i].size() < 2) continue;
sort(all(buck[i]),[&](int x,int y) {return tin[x] < tin[y];});
int sz = buck[i].size();
for (int j = 0;j<sz;j++) {
int a = buck[i][j];
int b = buck[i][(j+1)%sz];
int l = lca(a,b);
if (l != a && l != b) {
dp[a]++;
dp[b]++;
dp[l]-=2;
}
else {
if (l == b) swap(a,b);
dp[b]++;
dp[l]--;
}
}
}
dfs2(1,1);
dfs3(1,1);
vi cnt(n+1,0);
int leaf = 0;
stack<int> stk;
int sz = v.size();
for (int i = 0;i<sz;i++) {
v.push_back(lca(v[i],v[(i+1)%sz]));
}
sort(all(v),[&](int x,int y) {
return tin[x] < tin[y];
});
v.erase(unique(all(v)),v.end());
for (auto it : v) {
while (!stk.empty() && !anc(stk.top(),it)) stk.pop();
if (!stk.empty()) {
cnt[stk.top()]++;
cnt[it]++;
}
stk.push(it);
}
for (auto it : v) if (cnt[it] == 1) leaf++;
cout << (leaf+1)/2 << '\n';
}
// Entry point: fast unsynchronised iostreams, optional local file redirection
// (only when compiled with -DDodi), then a single call to solve().
// Declared int32_t because `int` is macro-redefined to long long.
int32_t main() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);
#ifdef Dodi
    freopen("in.txt","r",stdin);
    freopen("out.txt","w",stdout);
#endif
    // Single test case; a multi-test variant would read the count here.
    for (int tests = 1; tests > 0; --tests) {
        solve();
    }
}
/* Residue copied from an online-judge submission page (results table). It is not
   C++ and is kept inside this comment so the file still compiles.
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
*/