#include <bits/stdc++.h>
using namespace std;
// Print `value` followed by a flushed newline, then terminate the program with
// exit status 0. Used to emit the final answer from anywhere in main().
template <typename T>
void out(T value) {
    cout << value << endl;
    exit(0);
}
// Debug print: writes "<expression text> is <value>" plus a flushed newline to cout.
#define watch(x) cout << #x << " is " << (x) << endl
// Short alias for the built-in long long type.
using ll = long long;
// Disjoint-set union over the elements 0..n-1, using union by size and path
// compression. Call init(n) before using parent()/join().
struct dsu0 {
    vector<int> par, siz;   // par[x]: parent link of x; siz[r]: element count of the set rooted at r
    int n = 0;              // number of elements; valid ids are 0..n-1 (0 until init)
    int cc = 0;             // current number of disjoint sets
    int largest = 0;        // size of the largest set formed so far (1 right after init)

    // Reset to n singleton sets {0}, {1}, ..., {n-1}. The vectors carry 10 spare
    // slots beyond n; parent() asserts ids stay below n, so those are not addressed.
    void init(int n) {
        assert(n>0);
        this->n=n;
        cc=n;
        par.resize(n+10);
        siz.resize(n+10);
        for (int i=0; i<n; i++) par[i]=i, siz[i]=1;
        largest=1;
    }

    // Representative (root) of x's set. Every vertex on the path from x to the
    // root is re-pointed directly at the root (path compression). Iterative, so
    // call depth does not grow with the path length.
    int parent(int x) {
        assert(x>=0 && x<n);
        int root = x;
        while (par[root] != root) root = par[root];
        while (par[x] != root) {
            int next = par[x];
            par[x] = root;
            x = next;
        }
        return root;
    }

    // Unite the sets containing x and y. Returns false (changing nothing) when they
    // already share a set; otherwise hangs the smaller set's root under the larger
    // one's, updates cc/largest, and returns true.
    bool join(int x, int y) {
        x=parent(x);
        y=parent(y);
        if (x==y) return false;
        cc--;
        if (siz[x]<siz[y]) swap(x,y);
        siz[x]+=siz[y];
        par[y]=x;
        largest=max(largest,siz[x]);
        return true;
    }
};
const int maxn = 5e5 + 10;   // array capacity for vertices and colours
int n, k;                    // number of vertices, number of colours
vector<int> g[maxn];         // tree adjacency lists (0-indexed vertices)
int a[maxn];                 // a[v]: colour of vertex v (0-indexed)
const int LOG = 20;          // binary-lifting levels; 2^20 exceeds maxn
int tin[maxn];               // Euler-tour entry time of each vertex
int tout[maxn];              // Euler-tour exit time of each vertex
int cloc = 0;                // Euler-tour clock
int dep[maxn];               // depth of each vertex (root is 0)
// par[j][v]: the 2^j-th ancestor of v. Entries for the root (vertex 0) are never
// written, so they stay zero-initialised and the root maps to itself.
int par[LOG+1][maxn];

// Lowest common ancestor of u and v via binary lifting.
// Requires dep[] and par[][] to be filled for both vertices (dfs does this).
int lca(int u, int v) {
    if (dep[u]>dep[v]) swap(u,v);   // make v the deeper (or equally deep) vertex
    // Lift v by the depth difference so both vertices sit on the same level.
    // The jump must start from v itself; the deeper vertex is the one that climbs.
    int dx = dep[v]-dep[u];
    for (int j=LOG-1; j>=0; j--) {
        if (dx>>j&1) {
            v = par[j][v];
        }
    }
    if (u==v) return v;             // u was an ancestor of v
    // Raise both while their 2^j-th ancestors differ; they end as the two
    // children of the LCA on the respective paths.
    for (int j=LOG-1; j>=0; j--) {
        if (par[j][u]!=par[j][v]) {
            u=par[j][u];
            v=par[j][v];
        }
    }
    return par[0][v];
}
// bycolor[c]: entry times (tin) of the colour-c vertices. They are appended in the
// order the DFS assigns them, so each list is sorted ascending.
vector<int> bycolor[maxn];
// nodes[c]: ids of the vertices whose colour is c.
vector<int> nodes[maxn];

// Euler-tour DFS from `at`, whose tree parent is `p` (-1 for the root).
// Assigns tin/tout from the shared clock, records tin in bycolor, fills the
// binary-lifting rows par[1..LOG-1][at], and sets par[0]/dep for every child
// before descending into it. par[0][at] must already be set when `at` is
// entered; for the root this relies on the zero-initialised global.
// NOTE(review): recursion depth equals the tree height (up to n) — confirm the
// judge's stack limit is large enough for path-shaped inputs.
void dfs(int at, int p) {
    const int entry = cloc++;
    tin[at] = entry;
    bycolor[a[at]].push_back(entry);
    for (int j = 1; j < LOG; ++j) {
        const int half = par[j - 1][at];   // the 2^(j-1)-th ancestor
        par[j][at] = par[j - 1][half];
    }
    for (const int child : g[at]) {
        if (child != p) {
            par[0][child] = at;
            dep[child] = dep[at] + 1;
            dfs(child, at);
        }
    }
    tout[at] = cloc++;
}
dsu0 dsu;
void dfs2(int at, int p, int c) {
for (int to: g[at]) {
if (to == p) continue;
auto iter = lower_bound(bycolor[c].begin(), bycolor[c].end(), tin[to]);
if (iter != bycolor[c].end() && *iter <= tout[to]) {
dsu.join(at, to);
//cout<<at+1<<"-->"<<to+1<<endl;
dfs2(to, at, c);
}
}
}
// G[r]: representatives of the DSU components adjacent to the component whose
// representative is r (only representative indices are populated).
set<int> G[maxn];

// Walk the tree from `at` (tree parent `p`). For each tree edge whose endpoints
// fall in different DSU components, record an undirected edge between the two
// component representatives in G; std::set drops duplicates.
void dfs3(int at, int p) {
    for (const int child : g[at]) {
        if (child == p) continue;
        const int rootAt = dsu.parent(at);
        const int rootChild = dsu.parent(child);
        if (rootAt != rootChild) {
            G[rootAt].insert(rootChild);
            G[rootChild].insert(rootAt);
        }
        dfs3(child, at);
    }
}
int main() {
ios_base::sync_with_stdio(false); cin.tie(0); cout.tie(0);
cin>>n>>k;
for (int i=0; i<n-1; i++) {
int u,v; cin>>u>>v;
--u; --v;
g[u].push_back(v);
g[v].push_back(u);
}
for (int i=0; i<n; i++) {
cin>>a[i];
--a[i];
nodes[a[i]].push_back(i);
}
dfs(0,-1);
dsu.init(n);
// for each color, join them into a supernode
for (int j=0; j<k; j++) {
int mid = nodes[j][0];
for (int x: nodes[j]) {
mid = lca(mid, x);
}
dfs2(mid,par[0][mid],j);
}
set<int> st;
for (int i=0; i<n; i++) {
st.insert(dsu.parent(i));
}
if ((int)st.size()==1) out(0);
dfs3(0,-1);
int leaves = 0;
for (int x: st) {
if ((int)G[x].size() == 1) leaves++;
}
int ans = (leaves+1)/2;
out(ans);
return 0;
}
/* Pasted online-judge grader output (kept as a comment so the file compiles):
# |
Result |
Execution time |
Memory |
Grader output |
1 |
Correct |
38 ms |
59372 KB |
Output is correct |
2 |
Incorrect |
36 ms |
59244 KB |
Output isn't correct |
3 |
Halted |
0 ms |
0 KB |
- |
*/
/* Pasted online-judge grader output (kept as a comment so the file compiles):
# |
Result |
Execution time |
Memory |
Grader output |
1 |
Correct |
38 ms |
59372 KB |
Output is correct |
2 |
Incorrect |
36 ms |
59244 KB |
Output isn't correct |
3 |
Halted |
0 ms |
0 KB |
- |
*/
/* Pasted online-judge grader output (kept as a comment so the file compiles):
# |
Result |
Execution time |
Memory |
Grader output |
1 |
Correct |
38 ms |
59372 KB |
Output is correct |
2 |
Incorrect |
36 ms |
59244 KB |
Output isn't correct |
3 |
Halted |
0 ms |
0 KB |
- |
*/
/* Pasted online-judge grader output (kept as a comment so the file compiles):
# |
Result |
Execution time |
Memory |
Grader output |
1 |
Correct |
276 ms |
75492 KB |
Output is correct |
2 |
Correct |
306 ms |
95076 KB |
Output is correct |
3 |
Incorrect |
41 ms |
59884 KB |
Output isn't correct |
4 |
Halted |
0 ms |
0 KB |
- |
*/
/* Pasted online-judge grader output (kept as a comment so the file compiles):
# |
Result |
Execution time |
Memory |
Grader output |
1 |
Correct |
38 ms |
59372 KB |
Output is correct |
2 |
Incorrect |
36 ms |
59244 KB |
Output isn't correct |
3 |
Halted |
0 ms |
0 KB |
- |
*/