#include <bits/stdc++.h>
using namespace std;
// Size of every node- or state-indexed array below (ids are 1-based).
constexpr int maxn = 500010;
int n, k;              // number of tree nodes, number of states
int cset[maxn];        // DSU parent pointer of each state
int mstate[maxn];      // mstate[u]: state that node u belongs to
int croot[maxn];       // croot[s]: LCA of all nodes of state s (0 until assigned)
using pii = pair<int, int>;
vector<int> adj[maxn]; // adjacency lists of the tree
// Returns the representative of the DSU set that contains state u.
// Every state on the path from u to the representative is re-pointed
// directly at the representative (path compression).
int findset(int u) {
    int root = u;
    while (cset[root] != root) {
        root = cset[root];
    }
    while (cset[u] != root) {
        const int next = cset[u];
        cset[u] = root;
        u = next;
    }
    return root;
}
// Merges the DSU sets that contain states a and b. Which representative
// survives is decided by a coin flip (rand() % 2). The flip is drawn only
// when the two sets are actually different.
void unionset(int a, int b) {
    const int ra = findset(a);
    const int rb = findset(b);
    if (ra != rb) {
        if (rand() % 2 == 0) {
            cset[rb] = ra;
        } else {
            cset[ra] = rb;
        }
    }
}
int getcomp(int u) {
return findset(mstate[u]);
}
// Binary-lifting table: par[i][v] is the 2^i-th ancestor of node v, or 0 once
// that would climb above the root. Node ids start at 1, so 0 is a free
// sentinel and par[i][0] stays 0. Levels 0..19 cover depths below 2^20.
int par[20][maxn];
// dep[v]: depth of node v. The root passed to predfs has depth 1; dep[0] is 0.
int dep[maxn];

// Returns the ancestor of u that lies k levels above it.
// Expects 0 <= k < 2^20 and k < dep[u]. lca() passes a depth difference,
// which satisfies both. All 20 table levels are usable.
int walk(int u, int k) {
    for (int i = 0; i < 20 && k > 0; i++, k >>= 1) {
        if (k & 1) u = par[i][u];
    }
    return u;
}

// Lowest common ancestor of nodes a and b.
// First lift the deeper node to the depth of the shallower one. Then lift both
// by decreasing powers of two while their ancestors still differ. At that
// point their common parent is the answer.
int lca(int a, int b) {
    int x = a, y = b;
    if (dep[x] < dep[y]) {
        x = b;
        y = a;
    }
    x = walk(x, dep[x] - dep[y]);
    if (x == y) return x;
    for (int i = 19; i >= 0; i--) {
        if (par[i][x] != par[i][y]) {
            x = par[i][x];
            y = par[i][y];
        }
    }
    return par[0][x];
}

// Roots the tree at u and fills dep[] and every level of par[][] for all nodes
// reachable from u. p is u's parent, or -1 when u is the root; when p != -1,
// p's own dep/par entries are used as they are.
// Nodes are visited in preorder from an explicit stack, because a path-shaped
// tree with ~maxn nodes would overflow the call stack if this recursed. When a
// node is popped, its parent and all its ancestors already have complete rows,
// so the node's own row of par[][] can be built at once.
void predfs(int u, int p = -1) {
    vector<pair<int, int>> stk;  // (node, parent) pairs still to visit
    stk.emplace_back(u, p);
    while (!stk.empty()) {
        const int x = stk.back().first;
        const int px = stk.back().second;
        stk.pop_back();
        dep[x] = px == -1 ? 1 : dep[px] + 1;
        par[0][x] = px == -1 ? 0 : px;
        for (int i = 1; i < 20; i++) {
            par[i][x] = par[i - 1][par[i - 1][x]];
        }
        for (int y : adj[x]) {
            if (y != px) stk.emplace_back(y, x);
        }
    }
}
set<pii> curin;
vector<int> myremos[maxn];
void dfs(int u, int p = -1) {
for (int v : adj[u]) {
if (v == p) continue;
dfs(v, u);
}
if (curin.size()) {
int cv = (*(curin.begin())).second;
unionset(cv, mstate[u]);
}
curin.insert(pii(dep[mstate[u]], mstate[u]));
int cv = (*(curin.begin())).second;
for (int v : myremos[u]) {
unionset(cv, v);
curin.erase(curin.find(pii(dep[u], v)));
}
}
int deg[maxn];
int main() {
ios_base::sync_with_stdio(false);
cin.tie(NULL);
cin >> n >> k;
int u, v;
for (int i = 0; i < n-1; i++) {
cin >> u >> v;
adj[u].push_back(v);
adj[v].push_back(u);
}
for (int i = 1; i <= k; i++) {
cset[i] = i;
}
for (int i = 1; i <= n; i++) {
cin >> mstate[i];
}
predfs(1);
//now we are ready to union components
for (int i= 1; i <= n; i++) {
if (croot[mstate[i]] == 0) croot[mstate[i]] = i;
else {
croot[mstate[i]] = lca(croot[mstate[i]], i);
}
}
for (int i = 1; i <= n; i++) {
myremos[croot[i]].push_back(i);
}
for (int i = 1; i <= k; i++) {
reverse(myremos[i].begin(), myremos[i].end());
}
for (int i = 1; i <= n; i++) {
for (int j : adj[i]) {
if (getcomp(i) != getcomp(j)) {
deg[getcomp(i)]++;
deg[getcomp(j)]++;
}
}
}
int nc = 0;
for (int i = 1; i <= k; i++) {
deg[i] /= 2;
if (deg[i] == 1) nc++;
}
cout << (nc+1)/2 << endl;
}
/*
Online-judge results that were pasted after the program; kept inside a comment
so the file still compiles. Column headers: # | 결과 (result) |
실행 시간 (running time) | 메모리 (memory) | Grader output.

# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Correct |
24 ms |
23936 KB |
Output is correct |
2 |
Incorrect |
22 ms |
23928 KB |
Output isn't correct |
3 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Correct |
24 ms |
23936 KB |
Output is correct |
2 |
Incorrect |
22 ms |
23928 KB |
Output isn't correct |
3 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Correct |
24 ms |
23936 KB |
Output is correct |
2 |
Incorrect |
22 ms |
23928 KB |
Output isn't correct |
3 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Correct |
104 ms |
28952 KB |
Output is correct |
2 |
Correct |
102 ms |
32504 KB |
Output is correct |
3 |
Incorrect |
29 ms |
24184 KB |
Output isn't correct |
4 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Correct |
24 ms |
23936 KB |
Output is correct |
2 |
Incorrect |
22 ms |
23928 KB |
Output isn't correct |
3 |
Halted |
0 ms |
0 KB |
- |
*/