#include <bits/stdc++.h>
using namespace std;
// Disjoint-set union with union by size and path compression.
// Encoding of lab[x]:
//   lab[x] <  0 : x is a representative and -lab[x] is the size of its set;
//   lab[x] >= 0 : lab[x] is the parent of x.
struct disjoint_set{
    vector<int> lab;

    // n singleton sets {0}, {1}, ..., {n-1}.
    disjoint_set(int n) : lab(n, -1) {}

    // Returns the representative of x's set and points every vertex on the
    // walked path directly at it.
    int root(int x){
        int rep = x;
        while(lab[rep] >= 0) rep = lab[rep];
        while(x != rep){
            const int up = lab[x];
            lab[x] = rep;
            x = up;
        }
        return rep;
    }

    // Unites the sets containing a and b. Returns false if they already
    // shared a set, true if two sets were merged.
    bool join(int a, int b){
        int ra = root(a);
        int rb = root(b);
        if(ra == rb) return false;
        // Labels are negative sizes: keep the larger set's root (the smaller
        // label) as the new representative.
        if(lab[rb] < lab[ra]) swap(ra, rb);
        lab[ra] += lab[rb];
        lab[rb] = ra;
        return true;
    }

    // True when a and b are currently in the same set.
    bool same_set(int a, int b){
        const int ra = root(a);
        return ra == root(b);
    }
};
// Reads a tree on N vertices (edges given 1-indexed) and a colour S[i] in 1..K
// for every vertex, then prints the answer of the reduction below.
//
// A tree edge (v, par[v]) is contracted when subtree(v) contains some, but not
// all, of the vertices of at least one colour. The edges that are not
// contracted connect the resulting DSU components into a tree. If that tree
// has L leaves, the program prints (L + 1) / 2. This is the same leaf-pairing
// formula as BOI 2015 "Net".
int main(){
    ios_base::sync_with_stdio(0); cin.tie(0);
    int N, K;
    cin >> N >> K;
    // Nothing to do for an empty tree; avoids indexing adj[0] below.
    if(N <= 0){
        cout << 0 << '\n';
        return 0;
    }
    vector<vector<int>> adj(N);
    for(int i = 1; i < N; ++i){
        int u, v;
        cin >> u >> v;
        --u, --v;
        adj[u].emplace_back(v);
        adj[v].emplace_back(u);
    }
    // cnt[c] = total number of vertices having colour c (0-indexed).
    vector<int> S(N), cnt(K);
    for(int i = 0; i < N; ++i){
        cin >> S[i];
        --S[i];
        ++cnt[S[i]];
    }
    // mp[v]: colour -> number of its vertices in the part of subtree(v) merged
    // so far. An entry is erased as soon as its count reaches cnt[colour], so a
    // non-empty map means subtree(v) splits some colour. Colours that occur
    // only once can never be split and are left out entirely.
    vector<map<int, int>> mp(N);
    for(int i = 0; i < N; ++i){
        if(cnt[S[i]] > 1) ++mp[i][S[i]];
    }
    // Iterative BFS from vertex 0 replaces a recursive DFS. A path-shaped tree
    // would otherwise recurse N levels deep and could overflow the stack.
    vector<int> par(N, -1), order;
    order.reserve(N);
    order.push_back(0);
    for(size_t i = 0; i < order.size(); ++i){
        const int u = order[i];
        for(int v : adj[u]) if(v != par[u]){
            par[v] = u;
            order.push_back(v);
        }
    }
    disjoint_set dsu(N);
    // Reverse BFS order handles every child before its parent. By the time v
    // is merged upward, mp[v] therefore describes all of subtree(v).
    for(int i = (int)order.size() - 1; i > 0; --i){
        const int v = order[i];
        const int u = par[v];
        if(!mp[v].empty()){
            dsu.join(u, v);
        }
        // Small-to-large: always iterate over the smaller of the two maps.
        if(mp[u].size() < mp[v].size()){
            swap(mp[u], mp[v]);
        }
        for(const auto& [key, value] : mp[v]){
            auto it = mp[u].try_emplace(key, 0).first;
            it->second += value;
            // Every vertex of this colour now lies below u's merged part.
            if(it->second == cnt[key]) mp[u].erase(it);
        }
        // v's map is fully folded into u's; release it instead of keeping a
        // duplicate of every merged entry alive until exit.
        mp[v].clear();
    }
    // deg[r]: degree, in the contracted tree, of the component whose DSU
    // representative is r.
    vector<int> deg(N);
    for(size_t i = 1; i < order.size(); ++i){
        const int v = order[i];
        const int a = dsu.root(v);
        const int b = dsu.root(par[v]);
        if(a != b){
            ++deg[a];
            ++deg[b];
        }
    }
    int cnt_leaves = 0;
    for(int i = 0; i < N; ++i){
        if(dsu.root(i) == i && deg[i] == 1) ++cnt_leaves;
    }
    cout << (cnt_leaves + 1) / 2 << '\n';
    return 0;
}
// Subproblem: https://oj.uz/problem/view/BOI15_net (the final leaf-pairing step, answer = (leaves + 1) / 2)
/*
 * Residue pasted from the online judge's submission page (not part of the
 * program). It was previously left uncommented, which broke compilation:
 *
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 */