#include <bits/stdc++.h>
#pragma GCC optimize("O3")
#pragma GCC optimize("Ofast")
#pragma GCC optimize("unroll-loops")
#pragma GCC target("avx,avx2,fma")
using namespace std;
typedef long long ll;
typedef pair<ll, ll> pll;
typedef pair<int, int> pii;
#define MAX 505050   // capacity of the per-vertex / per-colour arrays below
#define MAXS 20      // number of binary-lifting levels stored per vertex
#define INF 1000000000000000001
#define bb ' '
#define ln '\n'
#define Ln '\n'
// fmax(a, b): assigns max(a, b) to a. The previous definition had an unbalanced
// '(' and would not compile wherever it was used.
#define fmax(a, b) ((a) = max((a), (b)))
vector<int> adj[MAX];        // adjacency lists of the input tree (uni() appends to them)
int sp[MAX][MAXS];           // sp[x][i]: 2^i-th ancestor of x, 0 above the root
int dep[MAX] = { 0, 1 };     // depth; vertex 1 (the dfs root) starts at depth 1
vector<int> locs[MAX];       // locs[c]: vertices whose colour is c
// lca() lifts by the bits of a depth difference using levels 0..MAXS-1, so
// 2^MAXS must exceed every possible depth (< MAX).
static_assert((1 << MAXS) > MAX, "MAXS too small for binary lifting over MAX vertices");
// Roots the subtree hanging from x (with parent p) and fills, for every vertex
// in it, the binary-lifting row sp[.][0..MAXS-1] and the depth dep[.].
// dep[x] itself must already be set by the caller (dep[1] = 1 comes from the
// array initializer and main calls dfs(1)).
// Iterative on purpose: a recursive version nests as deep as the tree height
// (up to N on a path), which can overflow the call stack. A vertex is only
// pushed after its parent has been popped and processed, so the parent's sp
// row is always complete before the child's row is built from it.
void dfs(int x, int p = 0) {
    vector<pair<int, int>> stk;  // (vertex, parent) pairs still to process
    stk.emplace_back(x, p);
    while (!stk.empty()) {
        const int u = stk.back().first;
        const int par = stk.back().second;
        stk.pop_back();
        sp[u][0] = par;
        for (int i = 1; i < MAXS; i++) sp[u][i] = sp[sp[u][i - 1]][i - 1];
        for (int v : adj[u]) {
            if (v == par) continue;
            dep[v] = dep[u] + 1;
            stk.emplace_back(v, u);
        }
    }
}
// Lowest common ancestor of u and v, using the binary-lifting table sp and the
// depth array dep that dfs() fills in.
int lca(int u, int v) {
    // Make v the deeper endpoint, then lift it to u's depth one set bit at a time.
    if (dep[u] > dep[v]) swap(u, v);
    const int diff = dep[v] - dep[u];
    for (int k = 0; k < MAXS; k++)
        if (diff & (1 << k)) v = sp[v][k];
    if (u == v) return u;  // u was an ancestor of v
    // Climb both endpoints while their 2^k-th ancestors still differ; afterwards
    // both sit directly below the common ancestor.
    for (int k = MAXS - 1; k >= 0; k--) {
        if (sp[u][k] == sp[v][k]) continue;
        u = sp[u][k];
        v = sp[v][k];
    }
    return sp[u][0];
}
int p[MAX];
int psum[MAX];
int col[MAX];
int find(int x) {
if (p[x] == x) return x;
return p[x] = find(p[x]);
}
// Merges the DSU sets containing u and v.
// The representative with the longer adjacency list survives (small-to-large).
// The shorter list's entries (original neighbour ids) are appended to it, so
// that adj[find(x)] accumulates the original neighbours of every member of x's
// set. nleaf() walks the contracted tree through exactly these lists.
// Bug fixed: the old loop pushed the absorbed root v once per entry instead of
// each neighbour, dropping every edge of the absorbed component.
void uni(int u, int v) {
    u = find(u);
    v = find(v);
    if (u == v) return;
    if (adj[u].size() < adj[v].size()) swap(u, v);
    adj[u].insert(adj[u].end(), adj[v].begin(), adj[v].end());
    p[v] = u;
}
// Turns psum into subtree sums over the tree rooted at x (parent p):
// afterwards psum[y] holds the sum of the original psum values in y's subtree.
// psum[p] itself is not modified.
// Iterative to avoid call-stack depth proportional to the tree height.
// Vertices are listed so that each parent precedes its children, then folded
// bottom-up in reverse order.
void calc(int x, int p = 0) {
    vector<pair<int, int>> order;  // (vertex, parent), parents before children
    order.emplace_back(x, p);
    for (size_t k = 0; k < order.size(); k++) {
        const int u = order[k].first;
        const int par = order[k].second;
        for (int v : adj[u])
            if (v != par) order.emplace_back(v, u);
    }
    // Index 0 is x itself; its parent p lies outside the subtree and is skipped.
    for (size_t k = order.size(); k-- > 1;)
        psum[order[k].second] += psum[order[k].first];
}
// Counts leaves of the contracted tree: among the DSU components reachable from
// component x without stepping back into component p, returns how many have
// exactly one neighbouring component. x is expected to be a representative
// (main passes find(1)).
// Neighbours of a component are read from adj[representative], which is
// expected to hold the original neighbours of all its members (see uni()).
// Entries that map back to the component itself are internal edges and are
// ignored. Iterative to avoid call-stack depth proportional to the tree height.
int nleaf(int x, int p = 0) {
    int leaves = 0;
    vector<pair<int, int>> stk;  // (component, parent component)
    stk.emplace_back(x, p);
    while (!stk.empty()) {
        const int u = stk.back().first;
        const int par = stk.back().second;
        stk.pop_back();
        int degree = 0;
        for (int w : adj[u]) {
            const int v = find(w);
            if (v == u) continue;  // edge inside the component
            degree++;
            if (v != par) stk.emplace_back(v, u);
        }
        if (degree == 1) leaves++;
    }
    return leaves;
}
// Reads the tree and colours, contracts every edge that lies on a path between
// two equally coloured vertices, and prints (L + 1) / 2, where L is the number
// of leaves of the contracted tree (as counted by nleaf).
signed main() {
    ios::sync_with_stdio(false), cin.tie(0);
    int N, K;
    cin >> N >> K;
    // N - 1 undirected edges.
    for (int e = 1; e < N; e++) {
        int a, b;
        cin >> a >> b;
        adj[a].push_back(b);
        adj[b].push_back(a);
    }
    dfs(1);
    // Group vertex ids by colour (appended in increasing id order).
    for (int v = 1; v <= N; v++) {
        cin >> col[v];
        locs[col[v]].push_back(v);
    }
    // For each colour, mark the path between each consecutive pair of its
    // vertices: +1 at both ends, -2 at their LCA. After calc() takes subtree
    // sums, psum[y] counts the marked paths using the edge y -> parent(y).
    for (int c = 1; c <= K; c++) {
        const vector<int>& group = locs[c];
        for (size_t j = 1; j < group.size(); j++) {
            const int x = group[j - 1];
            const int y = group[j];
            psum[y]++;
            psum[x]++;
            psum[lca(y, x)] -= 2;
        }
    }
    for (int v = 1; v <= N; v++) p[v] = v;
    calc(1);
    // Contract every edge that carries at least one marked path.
    for (int v = 2; v <= N; v++)
        if (psum[v] != 0) uni(v, sp[v][0]);
    cout << (nleaf(find(1)) + 1) / 2;
}
Compilation message
mergers.cpp: In function 'void uni(int, int)':
mergers.cpp:53:12: warning: unused variable 'x' [-Wunused-variable]
53 | for (auto x : adj[v]) adj[u].push_back(v);
| ^
mergers.cpp: In function 'int main()':
mergers.cpp:94:18: warning: comparison of integer expressions of different signedness: 'int' and 'std::vector<int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
94 | for (j = 1; j < locs[i].size(); j++) {
| ~~^~~~~~~~~~~~~~~~
# |
Verdict |
Execution time |
Memory |
Grader output |
1 |
Correct |
12 ms |
24056 KB |
Output is correct |
2 |
Incorrect |
13 ms |
24020 KB |
Output isn't correct |
3 |
Halted |
0 ms |
0 KB |
- |
# |
Verdict |
Execution time |
Memory |
Grader output |
1 |
Correct |
12 ms |
24056 KB |
Output is correct |
2 |
Incorrect |
13 ms |
24020 KB |
Output isn't correct |
3 |
Halted |
0 ms |
0 KB |
- |
# |
Verdict |
Execution time |
Memory |
Grader output |
1 |
Correct |
12 ms |
24056 KB |
Output is correct |
2 |
Incorrect |
13 ms |
24020 KB |
Output isn't correct |
3 |
Halted |
0 ms |
0 KB |
- |
# |
Verdict |
Execution time |
Memory |
Grader output |
1 |
Correct |
66 ms |
38116 KB |
Output is correct |
2 |
Correct |
66 ms |
39700 KB |
Output is correct |
3 |
Incorrect |
17 ms |
24476 KB |
Output isn't correct |
4 |
Halted |
0 ms |
0 KB |
- |
# |
Verdict |
Execution time |
Memory |
Grader output |
1 |
Correct |
12 ms |
24056 KB |
Output is correct |
2 |
Incorrect |
13 ms |
24020 KB |
Output isn't correct |
3 |
Halted |
0 ms |
0 KB |
- |