#include <bits/stdc++.h>
#define ll long long
#define pb push_back
#define mp make_pair
#define sz(x) int(x.size())
using namespace std;
// Global problem data. Arrays are indexed by town id (1..n) and by city id
// (1..k), so maxn must exceed the largest possible n and k.
// NOTE(review): the grader log appended to this file shows a signal-11 crash
// with the old bound of 100100; presumably the problem allows n up to 200000
// (JOISC 2020 "Capital City") — confirm against the statement.
const int maxn = 200100;
int n, k;                        // number of towns, number of cities
vector<int> g[maxn];             // tree adjacency lists, indexed by town
vector<int> towns[maxn];         // towns[c] = list of towns whose city is c
int city[maxn];                  // city[v] = city of town v
int cnt[maxn];                   // per-city town count inside the current component
int usize[maxn], uparent[maxn];  // union-find over city ids: set size / parent link
// Returns the representative of the union-find set containing `node` by
// following parent links until reaching a self-parented element.
// No path compression is performed; union by size keeps chains short.
int ufind(int node) {
    int rep = node;
    for (; uparent[rep] != rep; rep = uparent[rep]) {
    }
    return rep;
}
// Merges the union-find sets containing `a` and `b` (union by size).
// Returns false when they already share a set, true when a merge happened.
// On equal sizes, a's root is attached under b's root.
bool unite(int a, int b) {
    int keep = ufind(a);
    int drop = ufind(b);
    if (keep == drop) return false;
    // Make `keep` the surviving root: a's root only if strictly larger.
    if (!(usize[keep] > usize[drop])) swap(keep, drop);
    uparent[drop] = keep;
    usize[keep] += usize[drop];
    return true;
}
// Centroid-decomposition working state (globals, zero-initialized).
bool deleted[maxn] = {};  // true once a node has been used as a centroid
int sbt_sz[maxn] = {};    // subtree sizes for the current rooting of a component
int par[maxn] = {};       // parent for the current rooting; -1 at the root
int result = 0;           // best answer found so far (main sets it before solving)
vector<int> curr_nodes;   // nodes collected by get_sizes for the current component
// Roots the current (non-deleted) component at `node` with parent `p`.
// For every node v reached, sets par[v] and sbt_sz[v] (its subtree size), and
// appends v to curr_nodes.
// Iterative so that path-shaped trees cannot overflow the call stack: nodes
// are discovered in BFS order (parents before children), then subtree sizes
// are accumulated in reverse discovery order (children before parents).
void get_sizes(int node, int p) {
    const size_t first = curr_nodes.size();  // curr_nodes may already hold entries
    par[node] = p;
    curr_nodes.push_back(node);
    for (size_t idx = first; idx < curr_nodes.size(); ++idx) {
        const int u = curr_nodes[idx];
        sbt_sz[u] = 1;
        for (int v : g[u]) {
            if (deleted[v] || v == par[u]) continue;
            par[v] = u;
            curr_nodes.push_back(v);
        }
    }
    // Skip the root (index `first`); every later node has a valid parent.
    for (size_t idx = curr_nodes.size(); idx-- > first + 1;) {
        const int u = curr_nodes[idx];
        sbt_sz[par[u]] += sbt_sz[u];
    }
}
// Finds a centroid of the component rooted at `node`, using par[] and
// sbt_sz[] as filled by get_sizes. Nodes are scanned in BFS order from `node`;
// the first one whose largest remaining piece (child subtree, or the part of
// the component above it) is strictly smallest is returned.
int get_centroid(int node) {
    const int total = sbt_sz[node];
    int best = -1;
    int best_weight = INT_MAX;
    vector<int> order{node};  // BFS queue, consumed by advancing `head`
    for (size_t head = 0; head < order.size(); ++head) {
        const int u = order[head];
        int heaviest = 0;
        for (int v : g[u]) {
            if (deleted[v]) continue;
            const bool upward = (v == par[u]);
            const int piece = upward ? total - sbt_sz[u] : sbt_sz[v];
            heaviest = max(heaviest, piece);
            if (!upward) order.push_back(v);
        }
        if (heaviest < best_weight) {
            best_weight = heaviest;
            best = u;
        }
    }
    return best;
}
void solve(int root) {
get_sizes(root, -1); // Preprocess and find the subtree sizes of each node
root = get_centroid(root); // Find the centroid in this tree
for(int i:curr_nodes) {
cnt[city[i]]++;
}
bool check = (cnt[city[root]] == sz(towns[city[root]]));
int moves = 0;
if(check) {
queue<int>q;
for(int i:towns[city[root]]) {
q.push(i);
}
while(!q.empty()) {
int curr = q.front();
q.pop();
if(!check) break;
if(par[curr] == -1) continue;
if(unite(city[curr], city[par[curr]])) {
moves++;
check &= (cnt[city[par[curr]]] == sz(towns[city[par[curr]]]));
if(check) {
for(int i:towns[par[curr]]) {
q.push(i);
}
}
}
}
if(check) result = min(result, moves);
}
for(int i:curr_nodes) {
usize[city[i]] = 1;
uparent[city[i]] = city[i];
cnt[city[i]] = 0;
}
curr_nodes.clear();
deleted[root] = true;
for(int i:g[root]) {
if(!deleted[i]) {
solve(i);
}
}
}
// Reads n, k, the n-1 tree edges, and each town's city. Initializes the
// city union-find, runs the centroid decomposition from town 1, and prints
// the minimum number of merges found.
int main() {
    // Roughly 3n integers are read; unsync iostreams from C stdio for speed.
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);

    cin >> n >> k;
    for (int i = 0; i < n - 1; i++) {
        int a, b;
        cin >> a >> b;
        g[a].push_back(b);
        g[b].push_back(a);
    }
    for (int i = 1; i <= n; i++) {
        cin >> city[i];
        towns[city[i]].push_back(i);
    }
    result = INT_MAX;
    for (int i = 1; i <= k; i++) {
        usize[i] = 1;
        uparent[i] = i;
    }
    solve(1);
    cout << result << '\n';
}
/*
 * Judge results:
 * # | Result    | Time | Memory  | Grader output
 * 1 | Correct   | 7 ms | 5120 KB | Output is correct
 * 2 | Correct   | 7 ms | 4992 KB | Output is correct
 * 3 | Correct   | 7 ms | 4992 KB | Output is correct
 * 4 | Incorrect | 7 ms | 4992 KB | Output isn't correct
 * 5 | Halted    | 0 ms | 0 KB    | -
 */
/*
 * Judge results:
 * # | Result    | Time | Memory  | Grader output
 * 1 | Correct   | 7 ms | 5120 KB | Output is correct
 * 2 | Correct   | 7 ms | 4992 KB | Output is correct
 * 3 | Correct   | 7 ms | 4992 KB | Output is correct
 * 4 | Incorrect | 7 ms | 4992 KB | Output isn't correct
 * 5 | Halted    | 0 ms | 0 KB    | -
 */
/*
 * Judge results:
 * # | Result        | Time  | Memory  | Grader output
 * 1 | Runtime error | 15 ms | 9984 KB | Execution killed with signal 11 (could be triggered by violating memory limits)
 * 2 | Halted        | 0 ms  | 0 KB    | -
 */
/*
 * Judge results:
 * # | Result    | Time | Memory  | Grader output
 * 1 | Correct   | 7 ms | 5120 KB | Output is correct
 * 2 | Correct   | 7 ms | 4992 KB | Output is correct
 * 3 | Correct   | 7 ms | 4992 KB | Output is correct
 * 4 | Incorrect | 7 ms | 4992 KB | Output isn't correct
 * 5 | Halted    | 0 ms | 0 KB    | -
 */