#include<bits/stdc++.h>
#define ll long long
// NOTE: from here on every `int` in this file is textually replaced by
// `long long` (this is why `main` cannot be written as `int main`).
#define int long long
// nmax: size of every per-node array. N is not referenced in the code shown.
const int nmax = 5e5 + 5, N = 1e5;
// oo and base are not referenced in the code shown.
const ll oo = 1e18 + 1, base = 311;
// lg: highest binary-lifting level; up[][] stores levels 0..lg (2^19 > nmax).
// M is not referenced in the code shown.
const int lg = 19, M = 10;
// mod and mod2 are not referenced in the code shown.
const ll mod = 998244353, mod2 = 1e9 + 5277;
// Shorthand macros (pii/fi/se are not referenced in the code shown).
#define pii pair<int, int>
#define fi first
#define se second
// Makes any `endl` a plain newline (no flush).
#define endl "\n"
// Debug print of a[1..n]. Expands to TWO statements: the trailing newline is
// outside the loop, so do not use it as the body of an unbraced if/for.
#define debug(a, n) for(int i = 1; i <= n; ++i) cout << a[i] << ' '; cout << "\n";
using namespace std;
// n: node count, k: number of group labels; a[i]: label of node i (read in main).
int n, k, a[nmax];
// adj: tree adjacency lists; one[c]: nodes whose label is c.
vector<int> adj[nmax], one[nmax];
// Preorder data filled by dfs_1: st[u] = entry index, en[u] = last preorder
// index inside u's subtree, tour[i] = node at index i, sc = running counter.
int st[nmax], en[nmax], tour[nmax], sc = 0;
// h: depth (root = 0); up[u][j]: 2^j-th ancestor (0 above the root).
// par is not referenced in the code shown.
int h[nmax], up[nmax][lg + 1], par[nmax];
void dfs_1(int u, int p){
tour[++sc] = u;
st[u] = sc;
for(auto v : adj[u]){
if(v == p) continue;
h[v] = h[u] + 1;
up[v][0] = u;
for(int j = 1; j <= lg; ++j) up[v][j] = up[up[v][j - 1]][j - 1];
dfs_1(v, u);
}
en[u] = sc;
}
int lca(int u, int v){
if(h[u] != h[v]){
if(h[u] < h[v]) swap(u, v);
int k = h[u] - h[v];
for(int j = 0; j <= lg; ++j){
if(k >> j & 1){
u = up[u][j];
}
}
}
if(u == v) return u;
for(int j = __lg(h[u]); j >= 0; --j){
if(up[u][j] != up[v][j]){
u = up[u][j];
v = up[v][j];
}
}
return up[u][0];
}
int c[nmax];
void dfs_2(int u, int p){
for(auto v : adj[u]){
if(v == p) continue;
dfs_2(v, u);
c[u] += c[v];
}
}
int r[nmax];
// DSU find: returns the representative of u's set (a node with r == 0) and
// re-points every node on the walked path directly at it.
int get(int u){
    int root = u;
    while(r[root]) root = r[root];
    while(u != root){
        const int nxt = r[u];
        r[u] = root;
        u = nxt;
    }
    return root;
}
// DSU union: hangs the representative of u's set under that of v's set.
void Union(int u, int v){
    const int a = get(u), b = get(v);
    if(a == b) return;
    r[a] = b;
}
vector<int> gg[nmax];
int cnt = 0;
void dfs_3(int u, int p){
bool ok = 1;
for(auto v : gg[u]){
if(v == p) continue;
dfs_3(v, u);
ok = 0;
}
if(ok) ++cnt;
}
main(){
ios_base::sync_with_stdio(0);
cin.tie(0);cout.tie(0);
// freopen("code.inp", "r", stdin);
// freopen("code.out", "w", stdout);
cin >> n >> k;
for(int i = 1, u, v; i < n;++i){
cin >> u >> v;
adj[u].push_back(v);
adj[v].push_back(u);
}
for(int i =1; i <= n; ++i){
cin >> a[i];
one[a[i]].push_back(i);
}
dfs_1(1, -1);
vector<int> tmp, two;
for(int i = 1; i <= k; ++i){
sort(one[i].begin(), one[i].end(), [](int a, int b){
return st[a] < st[b];
});
tmp = one[i];
for(int j = 1; j < one[i].size(); ++j){
tmp.push_back(lca(one[i][j - 1], one[i][j]));
}
sort(tmp.begin(), tmp.end());
tmp.erase(unique(tmp.begin(), tmp.end()), tmp.end());
sort(tmp.begin(), tmp.end(), [](int a, int b){
return st[a] < st[b];
});
for(int j = 0; j < tmp.size(); ++j){
int u = tmp[j];
if(j == 0) two.push_back(u);
else{
while(1){
int x = two.back();
if(st[x] <= st[u] && en[u] <= en[x])break;
two.pop_back();
}
int x = two.back();
two.push_back(u);
c[x]--;
c[u]++;
}
}
two.clear();
}
dfs_2(1, -1);
for(int i = 2; i <= n; ++i){
int x = up[i][0];
if(c[i]){
Union(i, x);
}
}
for(int i = 1; i <= n; ++i){
int u = get(i);
for(auto v : adj[i]){
int x = get(i), y = get(v);
if(x != y){
gg[x].push_back(y);
}
}
}
dfs_3(get(1), -1);
cout << (cnt + 1) / 2;
}
/*
Sample input (expected output: 1):
5 4
1 2
2 3
3 4
3 5
1 2 1 3 4
*/
/*
Compilation message (line numbers refer to mergers.cpp as submitted, not to this listing):
mergers.cpp:84:1: warning: ISO C++ forbids declaration of 'main' with no type [-Wreturn-type]
84 | main(){
| ^~~~
mergers.cpp: In function 'int main()':
mergers.cpp:106:26: warning: comparison of integer expressions of different signedness: 'long long int' and 'std::vector<long long int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
106 | for(int j = 1; j < one[i].size(); ++j){
| ~~^~~~~~~~~~~~~~~
mergers.cpp:115:26: warning: comparison of integer expressions of different signedness: 'long long int' and 'std::vector<long long int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
115 | for(int j = 0; j < tmp.size(); ++j){
| ~~^~~~~~~~~~~~
mergers.cpp:140:13: warning: unused variable 'u' [-Wunused-variable]
140 | int u = get(i);
| ^
*/
/*
Grader results:
 # | Result    | Execution time | Memory   | Grader output
 1 | Correct   | 13 ms          | 35676 KB | Output is correct
 2 | Correct   | 13 ms          | 35772 KB | Output is correct
 3 | Correct   | 15 ms          | 35636 KB | Output is correct
 4 | Correct   | 15 ms          | 35672 KB | Output is correct
 5 | Incorrect | 15 ms          | 35544 KB | Output isn't correct
 6 | Halted    | 0 ms           | 0 KB     | -
*/
/*
Grader results:
 # | Result    | Execution time | Memory   | Grader output
 1 | Correct   | 13 ms          | 35676 KB | Output is correct
 2 | Correct   | 13 ms          | 35772 KB | Output is correct
 3 | Correct   | 15 ms          | 35636 KB | Output is correct
 4 | Correct   | 15 ms          | 35672 KB | Output is correct
 5 | Incorrect | 15 ms          | 35544 KB | Output isn't correct
 6 | Halted    | 0 ms           | 0 KB     | -
*/
/*
Grader results:
 # | Result    | Execution time | Memory   | Grader output
 1 | Correct   | 13 ms          | 35676 KB | Output is correct
 2 | Correct   | 13 ms          | 35772 KB | Output is correct
 3 | Correct   | 15 ms          | 35636 KB | Output is correct
 4 | Correct   | 15 ms          | 35672 KB | Output is correct
 5 | Incorrect | 15 ms          | 35544 KB | Output isn't correct
 6 | Halted    | 0 ms           | 0 KB     | -
*/
/*
Grader results:
 # | Result    | Execution time | Memory   | Grader output
 1 | Incorrect | 77 ms          | 63172 KB | Output isn't correct
 2 | Halted    | 0 ms           | 0 KB     | -
*/
/*
Grader results:
 # | Result    | Execution time | Memory   | Grader output
 1 | Correct   | 13 ms          | 35676 KB | Output is correct
 2 | Correct   | 13 ms          | 35772 KB | Output is correct
 3 | Correct   | 15 ms          | 35636 KB | Output is correct
 4 | Correct   | 15 ms          | 35672 KB | Output is correct
 5 | Incorrect | 15 ms          | 35544 KB | Output isn't correct
 6 | Halted    | 0 ms           | 0 KB     | -
*/