#include <bits/stdc++.h>
using namespace std;
// Fixed capacity for vertex ids and 0-based color ids.
// NOTE(review): inputs are assumed to stay below this; confirm against the task limits.
const int MAXN = 200069;

// n: number of vertices; k: number of colors (read in main, otherwise unused).
// a, b: scratch endpoints while reading edges.
// ts: size of the component currently being decomposed (set in solve()).
int n,k,a,b,ts;
// grph[v]: adjacency list of vertex v.
// towns[c]: intended to hold the vertices of color c; solve() iterates it,
// so it must be filled while the colors are read.
std::vector<int> grph[MAXN],towns[MAXN];
// col[v]: 0-based color of vertex v.
// cnt[c]: occurrences of color c in the whole tree.
// curcnt[c]: occurrences of color c inside the component being processed.
// done[c]: nonzero once color c has been added to the closure grown in solve().
int col[MAXN],cnt[MAXN],curcnt[MAXN],done[MAXN];
// sz[v], par[v]: subtree size and parent recorded by the most recent dfs().
int sz[MAXN],par[MAXN];
// cent[v]: v has already been used as a centroid and is treated as removed.
bool cent[MAXN];
// Reset curcnt[] and done[] for the color of every vertex reachable from x
// without stepping onto lst or onto a vertex already marked in cent[].
// Uses an explicit stack instead of recursion: on a path-shaped tree the
// component depth can approach n, which would overflow the call stack.
void clear(int x, int lst=-1) {
    std::vector<std::pair<int,int>> pending;  // (vertex, vertex we came from)
    pending.emplace_back(x, lst);
    while (!pending.empty()) {
        int u = pending.back().first;
        int from = pending.back().second;
        pending.pop_back();
        curcnt[col[u]]=0;
        done[col[u]]=0;
        for (int v : grph[u]) {
            if (v==from||cent[v]) {
                continue;
            }
            pending.emplace_back(v, u);
        }
    }
}
// Traverse the component hanging from x (x's parent is lst; vertices in cent[]
// are skipped). For every visited vertex this sets par[] and sz[] (subtree
// size in the tree rooted at x) and increments curcnt[] for its color.
// Returns sz[x], the number of vertices in the component.
// Iterative (pre-order list, then reverse accumulation) so that deep/path-like
// components do not exhaust the call stack.
int dfs(int x, int lst) {
    std::vector<int> order;  // vertices in pre-order: parents before children
    par[x]=lst;
    order.push_back(x);
    for (std::size_t idx = 0; idx < order.size(); idx++) {
        int u = order[idx];
        sz[u]=1;
        curcnt[col[u]]++;
        for (int v : grph[u]) {
            if (v==par[u]||cent[v]) {
                continue;
            }
            par[v]=u;
            order.push_back(v);
        }
    }
    // Children appear after their parent, so a reverse sweep finalizes each
    // subtree size before it is added to the parent's. Index 0 is the root.
    for (std::size_t idx = order.size(); idx-- > 1;) {
        int u = order[idx];
        sz[par[u]]+=sz[u];
    }
    return sz[x];
}
// Starting at x (arriving from lst), keep stepping into the first neighbour
// that is not removed and whose sz[] exceeds half of ts, until no such
// neighbour exists; that vertex is returned as the centroid.
// Relies on sz[] having been filled by dfs() rooted at x, with ts = sz[x].
int fndcent(int x, int lst=-1) {
    int cur = x;
    int from = lst;
    for (;;) {
        int heavy = -1;
        for (int v : grph[cur]) {
            bool eligible = v != from && !cent[v];
            if (eligible && sz[v]*2 > ts) {
                heavy = v;
                break;
            }
        }
        if (heavy < 0) {
            return cur;
        }
        from = cur;
        cur = heavy;
    }
}
// Centroid decomposition over the component containing x (vertices not in cent[]).
// At each centroid c, grow a set of colors starting from col[c]: whenever a
// color is added, the color of the parent (in the tree rooted at c) of each of
// its vertices is added too. If an added color has vertices outside the
// component (curcnt != cnt), the centroid is infeasible and scores INT_MAX.
// Returns the minimum over all centroids of (colors in the set - 1).
// NOTE(review): the "- 1" assumes the task (JOI 2020 "Capital City") asks for
// the number of merges rather than the number of cities; confirm.
int solve(int x) {
    // Pass 1: subtree sizes from x, used only to locate the centroid.
    clear(x);
    ts=dfs(x,x);
    x=fndcent(x);
    // Pass 2: re-root at the centroid. The closure walk below follows par[]
    // and needs it to point toward the centroid, not toward the old root.
    clear(x);
    dfs(x,x);

    std::stack<int> sk;
    sk.push(col[x]);
    int colors=0;
    bool feasible=true;
    while (!sk.empty()) {
        int tmp=sk.top();
        sk.pop();
        if (done[tmp]) {
            continue;
        }
        done[tmp]=true;
        // Some vertex of this color lies outside the component.
        if (curcnt[tmp]!=cnt[tmp]) {
            feasible=false;
            break;
        }
        colors++;
        // Closure rule: the parent (toward the centroid) of every vertex of an
        // included color must have its color included as well.
        for (int i : towns[tmp]) {
            sk.push(col[par[i]]);
        }
    }
    int ans = feasible ? colors-1 : INT_MAX;

    cent[x]=true;
    for (int i : grph[x]) {
        if (!cent[i]) {
            ans=std::min(ans,solve(i));
        }
    }
    return ans;
}
int main()
{
ios_base::sync_with_stdio(false);
cin.tie(nullptr);
cin>>n>>k;
for (int i = 0; i < n-1; i++) {
cin>>a>>b;
a--;b--;
grph[a].push_back(b);
grph[b].push_back(a);
}
for (int i = 0; i < n; i++) {
cin>>col[i];
col[i]--;
cnt[col[i]]++;
}
cout<<solve(0)<<endl;
return 0;
}
/*
 * Grader results pasted from the judge (four result tables; Korean column
 * headers translated).
 * Columns: # | Result | Execution time | Memory | Grader output
 *
 * Table 1:
 *   1 | Incorrect | 5 ms   | 9684 KB  | Output isn't correct
 *   2 | Halted    | 0 ms   | 0 KB     | -
 * Table 2:
 *   1 | Incorrect | 5 ms   | 9684 KB  | Output isn't correct
 *   2 | Halted    | 0 ms   | 0 KB     | -
 * Table 3:
 *   1 | Incorrect | 396 ms | 35492 KB | Output isn't correct
 *   2 | Halted    | 0 ms   | 0 KB     | -
 * Table 4:
 *   1 | Incorrect | 5 ms   | 9684 KB  | Output isn't correct
 *   2 | Halted    | 0 ms   | 0 KB     | -
 */