#include <bits/stdc++.h>
// Shorthand type aliases (not referenced in the visible code).
#define ll long long
#define ld long double
// Character constants; `en` is the newline written at the end of main()'s output.
#define sp ' '
#define en '\n'
// In-place min/max updates (not referenced in the visible code).
// NOTE(review): the macro arguments are not parenthesized, so arguments containing
// low-precedence operators or side effects may expand unexpectedly.
#define smin(a, b) a = min(a, b)
#define smax(a, b) a = max(a, b)
// Brings all of std into scope; the rest of the file uses vector, cin, cout, ios
// (and min/max inside the macros above) unqualified.
using namespace std;
// Capacity of the per-city arrays (cities are 0-indexed, so valid indices are < N).
const int N = 5e5 + 2;
// Not referenced in the visible code.
const int M = 5e5 + 2;
// Not referenced in the visible code.
int mod = 1000000007;
// Number of binary-lifting levels: 2^19 = 524288 >= N, so every ancestor is reachable.
const int LOG = 19;
// g[v]: neighbours of city v in the tree (0-indexed).
std::vector<int> g[N];
// lift[v][i]: the 2^i-th ancestor of v, saturating at the DFS root (the root is its own parent).
// in[v]:  preorder entry time of v (1-based).
// out[v]: largest entry time inside v's subtree, so subtree(v) occupies [in[v], out[v]].
// tsz:    the preorder timer.
int lift[N][LOG], in[N], out[N], tsz;

// Preorder DFS from s, whose parent is e (pass a negative e for the root).
// Fills in/out and the binary-lifting table lift for every vertex of s's subtree.
void Dfs(int s, int e) {
  in[s] = ++tsz;
  // The root points to itself so that lifting never indexes a negative vertex.
  lift[s][0] = e < 0 ? s : e;
  // Level 0 is already set; each higher level doubles the previous jump. Ancestors are
  // visited first in preorder, so their rows are already complete.
  for (int i = 1; i < LOG; i++) lift[s][i] = lift[lift[s][i - 1]][i - 1];
  for (int u : g[s]) {
    if (u != e) Dfs(u, s);
  }
  out[s] = tsz;
}

// True iff a is an ancestor of b; every vertex counts as its own ancestor.
bool Ancestor(int a, int b) {
  return in[a] <= in[b] && out[b] <= out[a];
}

// Lowest common ancestor of a and b. Requires Dfs to have been run on the tree.
int Lca(int a, int b) {
  if (Ancestor(a, b)) return a;
  if (Ancestor(b, a)) return b;
  // Lift a as high as possible while it is still not an ancestor of b; its parent is then
  // the LCA. Jumps that saturate at the root are rejected because the root is an ancestor of b.
  for (int i = LOG - 1; i >= 0; i--) {
    if (!Ancestor(lift[a][i], b)) a = lift[a][i];
  }
  return lift[a][0];
}
// sum[v]: on entry, whatever main() seeded at v; after Dfs1, the total over v's subtree.
// sz[v]:  number of vertices in v's subtree.
// ok[v]:  1 when sum[v] equals sz[v] after aggregation, else 0.
// a[v]:   0-indexed state of city v (read by main()).
int sum[N], sz[N], ok[N], a[N];

// Post-order pass from s (parent e): accumulates subtree sizes and subtree sums,
// then flags each vertex whose subtree sum matches its subtree size.
void Dfs1(int s, int e) {
  int total = 1;
  for (int child : g[s]) {
    if (child == e) continue;
    Dfs1(child, s);
    total += sz[child];
    sum[s] += sum[child];
  }
  sz[s] = total;
  ok[s] = (sum[s] == sz[s]) ? 1 : 0;
}
int dp[N][3];
int F(int s) {
return (dp[s][1] + 1) / 2 + dp[s][2];
}
void Dfs2(int s, int e) {
dp[s][0] = ok[s];
bool f = 1;
for (int u : g[s]) {
if (u != e) {
Dfs2(u, s);
dp[s][0] += dp[u][0];
if (F(u) <= dp[u][0]) {
f = 0;
dp[s][1] += dp[u][1];
dp[s][2] += dp[u][2];
}
else dp[s][1] += dp[u][0];
}
}
if (a[s] != a[0]) dp[s][2] += f;
}
int main() {
ios::sync_with_stdio(false);
cin.tie(0);
cout.tie(0);
int n, k;
cin >> n >> k;
for (int i = 0; i < n - 1; i++) {
int a, b;
cin >> a >> b;
g[a - 1].push_back(b - 1);
g[b - 1].push_back(a - 1);
}
Dfs(0, -1);
vector<vector<int>> pos(k);
for (int i = 0; i < n; i++) {
cin >> a[i];
a[i] -= 1;
pos[a[i]].push_back(i);
}
for (int i = 0; i < k; i++) {
if (pos[i].empty()) continue;
int tr = pos[i][0];
for (int j = 0; j < pos[i].size(); j++) {
tr = Lca(tr, pos[i][j]);
}
sum[tr] += pos[i].size();
}
Dfs1(0, -1);
Dfs2(0, -1);
int x, y;
x = y = 0;
for (int j : g[0]) {
int u = (x + dp[j][0] + 1) / 2 + y;
int v = (x + dp[j][1] + 1) / 2 + y + dp[j][2];
if (u < v) {
x += dp[j][0];
}
else {
x += dp[j][1];
y += dp[j][2];
}
}
cout << (x + 1) / 2 + y << en;
return 0;
}
/*
Compilation message:
mergers.cpp: In function 'int main()':
mergers.cpp:91:27: warning: comparison of integer expressions of different signedness: 'int' and 'std::vector<int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
   91 | for (int j = 0; j < pos[i].size(); j++) {
      |                 ~~^~~~~~~~~~~~~~~
*/
/*
# | Verdict   | Execution time | Memory   | Grader output
1 | Correct   | 6 ms           | 12116 KB | Output is correct
2 | Incorrect | 7 ms           | 12032 KB | Output isn't correct
3 | Halted    | 0 ms           | 0 KB     | -
*/
/*
# | Verdict   | Execution time | Memory   | Grader output
1 | Correct   | 6 ms           | 12116 KB | Output is correct
2 | Incorrect | 7 ms           | 12032 KB | Output isn't correct
3 | Halted    | 0 ms           | 0 KB     | -
*/
/*
# | Verdict   | Execution time | Memory   | Grader output
1 | Correct   | 6 ms           | 12116 KB | Output is correct
2 | Incorrect | 7 ms           | 12032 KB | Output isn't correct
3 | Halted    | 0 ms           | 0 KB     | -
*/
/*
# | Verdict   | Execution time | Memory   | Grader output
1 | Correct   | 66 ms          | 28196 KB | Output is correct
2 | Correct   | 88 ms          | 33736 KB | Output is correct
3 | Incorrect | 8 ms           | 12600 KB | Output isn't correct
4 | Halted    | 0 ms           | 0 KB     | -
*/
/*
# | Verdict   | Execution time | Memory   | Grader output
1 | Correct   | 6 ms           | 12116 KB | Output is correct
2 | Incorrect | 7 ms           | 12032 KB | Output isn't correct
3 | Halted    | 0 ms           | 0 KB     | -
*/