#include <bits/stdc++.h>
// Short type aliases.
// NOTE(review): of all the macros below, only `en` is used in the visible code.
#define ll long long
#define ld long double
// Single-character output separators.
#define sp ' '
#define en '\n'
// In-place min/max updates. Arguments are substituted textually and are not
// parenthesized, so pass simple lvalues/expressions only.
#define smin(a, b) a = min(a, b)
#define smax(a, b) a = max(a, b)
using namespace std;
// Capacity of every per-vertex array (vertices are 0-indexed, so n must be <= N).
constexpr int N = 500000 + 2;
// NOTE(review): M and mod are not referenced anywhere in the visible code.
constexpr int M = 500000 + 2;
int mod = 1000000007;

// Adjacency lists of the undirected tree (main inserts every edge both ways).
vector<int> g[N];
// lift[v][i] = 2^i-th ancestor of v, filled by Dfs; 19 levels allow jumps of up to 2^19 - 1.
int lift[N][19];
// Euler-tour entry / exit times filled by Dfs; tsz is the running timer.
int in[N];
int out[N];
int tsz = 0;
// Pre-order DFS that assigns Euler-tour times and fills the binary-lifting table.
//   s: current vertex
//   e: parent of s, or -1 when s is the root
// Afterwards:
//   - [in[v], out[v]] is the range of entry times inside v's subtree.
//   - lift[v][i] is the 2^i-th ancestor of v, saturating at the root,
//     which is stored as its own parent so the table never holds -1.
// NOTE(review): recursion depth equals the tree height (up to n); confirm the
// judge's stack limit is large enough for path-shaped trees.
void Dfs(int s, int e) {
    in[s] = ++tsz;
    // The root has no parent. Using -1 here would later be used as an array index.
    lift[s][0] = (e < 0 ? s : e);
    // Start at level 1: level 0 was set just above, and each level is built from
    // the previous one. Ancestors are visited first, so their rows are complete.
    for (int i = 1; i < 19; i++) lift[s][i] = lift[lift[s][i - 1]][i - 1];
    for (int u : g[s]) {
        if (u != e) Dfs(u, s);
    }
    out[s] = tsz;
}
// True when a is an ancestor of b (a vertex counts as its own ancestor):
// b's Euler-tour interval [in[b], out[b]] lies inside a's interval.
bool Ancestor(int a, int b) {
    const bool startsInside = in[b] >= in[a];
    const bool endsInside = out[a] >= out[b];
    return startsInside && endsInside;
}
// Lowest common ancestor of vertices a and b.
// After the two direct ancestor checks, climb from a to its highest ancestor
// that is still not an ancestor of b; that vertex's parent is the answer.
int Lca(int a, int b) {
    if (Ancestor(a, b)) {
        return a;
    }
    if (Ancestor(b, a)) {
        return b;
    }
    int cur = a;
    for (int level = 18; level >= 0; --level) {
        const int up = lift[cur][level];
        // Entries equal to 0 are skipped. Vertex 0 is the root, which is always
        // an ancestor of b, so it would never be a valid jump target anyway.
        if (up != 0 && !Ancestor(up, b)) {
            cur = up;
        }
    }
    return lift[cur][0];
}
int sum[N];  // per-vertex city counts seeded in main, turned into subtree totals by Dfs1
int sz[N];   // subtree sizes, computed by Dfs1
int ok[N];   // 1 when sum[v] == sz[v] after Dfs1, otherwise 0
int a[N];    // 0-indexed state of each city, read in main
int cnt[N];  // degree of each component head in the contracted tree, counted by Dfs2
// Post-order pass over the subtree of s (e is s's parent, -1 for the root).
// Afterwards:
//   - sz[s] is the number of vertices in s's subtree.
//   - sum[s] has absorbed every child's sum, so it is a subtree total.
//   - ok[s] is 1 exactly when sum[s] == sz[s].
// main seeds sum[] by adding each state's city count at the LCA of its cities.
// With that seeding, ok[s] == 1 means every city in s's subtree belongs to a
// state whose cities all lie inside that subtree.
void Dfs1(int s, int e) {
    sz[s] = 1;
    for (const int child : g[s]) {
        if (child == e) continue;
        Dfs1(child, s);
        sz[s] += sz[child];
        sum[s] += sum[child];
    }
    ok[s] = (sum[s] == sz[s]) ? 1 : 0;
}
// Counts vertex degrees in the tree obtained by contracting every edge
// (v, parent of v) with ok[v] == 0.
// Each vertex with ok[v] == 1 heads one contracted component. p is the head of
// the nearest such component strictly above s, or -1 when there is none.
// Each link between a head and the head above it adds 1 to cnt[] at both ends.
void Dfs2(int s, int e, int p) {
    const bool isHead = (ok[s] == 1);
    if (isHead && p != -1) {
        ++cnt[p];
        ++cnt[s];
    }
    // Whenever ok[s] is set, s becomes the nearest head for everything below it.
    const int headBelow = ok[s] ? s : p;
    for (const int child : g[s]) {
        if (child == e) continue;
        Dfs2(child, s, headBelow);
    }
}
// Reads the tree (n vertices, 1-indexed edges) and each city's state (1..k),
// then marks every vertex v whose subtree contains only cities of states lying
// entirely inside it (ok[v] == 1).
// Contracting all other edges leaves a tree of components. The program prints
// ceil(L / 2), where L is the number of components of degree 1. L is 0 when
// only the root component exists, giving 0.
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    int n, k;
    cin >> n >> k;
    for (int i = 0; i < n - 1; i++) {
        // Named u/v so they do not shadow the global state array a[].
        int u, v;
        cin >> u >> v;
        g[u - 1].push_back(v - 1);
        g[v - 1].push_back(u - 1);
    }
    Dfs(0, -1);

    // pos[c] lists the 0-indexed cities of 0-indexed state c.
    vector<vector<int>> pos(k);
    for (int i = 0; i < n; i++) {
        cin >> a[i];
        a[i] -= 1;
        pos[a[i]].push_back(i);
    }
    // Add each state's city count at the LCA of its cities. A subtree total of
    // sum[] then counts exactly the cities of states that lie fully inside it.
    for (const vector<int>& cities : pos) {
        if (cities.empty()) continue;
        int tr = cities.front();
        // Start at 1: folding the first city into itself would be a no-op.
        for (size_t j = 1; j < cities.size(); j++) tr = Lca(tr, cities[j]);
        sum[tr] += static_cast<int>(cities.size());
    }
    Dfs1(0, -1);
    Dfs2(0, -1, -1);

    int leaves = 0;
    for (int i = 0; i < n; i++) {
        if (cnt[i] == 1) leaves++;
    }
    cout << (leaves + 1) / 2 << en;
    return 0;
}
Compilation message
mergers.cpp: In function 'int main()':
mergers.cpp:85:27: warning: comparison of integer expressions of different signedness: 'int' and 'std::vector<int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
85 | for (int j = 0; j < pos[i].size(); j++) {
| ~~^~~~~~~~~~~~~~~
| # | Result | Execution time | Memory | Grader output |
|---|---|---|---|---|
| 1 | Correct | 7 ms | 12116 KB | Output is correct |
| 2 | Incorrect | 8 ms | 12116 KB | Output isn't correct |
| 3 | Halted | 0 ms | 0 KB | - |
| # | Result | Execution time | Memory | Grader output |
|---|---|---|---|---|
| 1 | Correct | 7 ms | 12116 KB | Output is correct |
| 2 | Incorrect | 8 ms | 12116 KB | Output isn't correct |
| 3 | Halted | 0 ms | 0 KB | - |
| # | Result | Execution time | Memory | Grader output |
|---|---|---|---|---|
| 1 | Correct | 7 ms | 12116 KB | Output is correct |
| 2 | Incorrect | 8 ms | 12116 KB | Output isn't correct |
| 3 | Halted | 0 ms | 0 KB | - |
| # | Result | Execution time | Memory | Grader output |
|---|---|---|---|---|
| 1 | Correct | 61 ms | 25508 KB | Output is correct |
| 2 | Correct | 78 ms | 31204 KB | Output is correct |
| 3 | Incorrect | 8 ms | 12628 KB | Output isn't correct |
| 4 | Halted | 0 ms | 0 KB | - |
| # | Result | Execution time | Memory | Grader output |
|---|---|---|---|---|
| 1 | Correct | 7 ms | 12116 KB | Output is correct |
| 2 | Incorrect | 8 ms | 12116 KB | Output isn't correct |
| 3 | Halted | 0 ms | 0 KB | - |