// This submission was migrated from a previous version of oj.uz, which graded on a different machine; resubmitting it may produce a different result.
#include <bits/stdc++.h>
using namespace std;
// ---- Limits ----
constexpr int mxN=200005;   // array bound for vertex ids and colour ids
constexpr int mxK=25;       // columns of the binary-lifting table (dfs1/lca use 0..19)

// ---- Shorthands ----
using ll = long long;
using pii = pair<int, int>;
using pll = pair<ll, ll>;
#define fi first
#define se second
#define all(x) x.begin(), x.end()

constexpr ll MOD=1000000007;
constexpr ll INF=1e18;

// ---- Problem state ----
int N, K;                   // number of vertices, number of colours
vector<int> v[mxN];         // adjacency lists of the tree
vector<int> ct[mxN];        // ct[c]: vertices that originally have colour c
int C[mxN];                 // current colour of each vertex (rewritten by merges)
int dep[mxN];               // depth of each vertex (root at 0)
int par[mxN];               // parent of each vertex (0 for the root)
int sps[mxN][mxK];          // sps[u][i]: 2^i-th ancestor of u
int root[mxN];              // root[c]: LCA of the vertices currently of colour c
int in[mxN];                // Euler-tour entry index
int out[mxN];               // largest entry index inside the subtree
int iidx;                   // running Euler-tour counter
set<pii> s;                 // pending colours keyed by (class size, colour)
// Orders vertex ids by Euler-tour entry index in[], so a set<int, cmp1>
// iterates its vertices in DFS pre-order and lower_bound searches by position.
struct cmp1
{
    bool operator()(int lhs, int rhs) const
    {
        return in[lhs] < in[rhs];
    }
};
set <int, cmp1> et[mxN];   // et[c]: vertices currently of colour c, in Euler order
set <int> np[mxN];         // np[c]: parent vertices on the boundary of colour c (filled in main/onion)
int cnt[mxN];              // cnt[c]: number of original colours merged into class c
int ans=mxN;               // minimum cnt recorded so far; mxN means "none yet"
// Depth-first traversal starting at `now` that fills, for every vertex u reached:
//   in[u], out[u] - Euler-tour entry index, and the largest entry index in u's subtree;
//   dep[], par[]  - depth and parent of every vertex other than the start
//                   (the start's own dep/par/sps[.][0] are left as they were);
//   sps[u][i]     - 2^i-th ancestor of u for i = 1..19 (sps[u][0] is the parent).
// `pre` is the neighbour of `now` that must not be entered (-1 = none).
//
// An explicit stack replaces recursion: on a path-shaped tree the recursion depth
// would reach N-1 (~2e5 frames), which can overflow the default thread stack.
// Neighbours are visited in the same order as a recursive DFS, so all outputs match.
void dfs1(int now, int pre=-1)
{
    // Give u its entry index and build its lifting row; sps[u][0] must already
    // hold u's parent, and all of u's ancestors' rows are complete at this point.
    auto enter=[](int u)
    {
        in[u]=++iidx;
        for(int i=1;i<=19;i++) sps[u][i]=sps[sps[u][i-1]][i-1];
    };

    // Frame layout: {vertex, vertex it was entered from, index of next neighbour to try}.
    vector<array<int, 3>> stk;
    enter(now);
    stk.push_back({now, pre, 0});
    while(!stk.empty())
    {
        const int u=stk.back()[0];
        const int from=stk.back()[1];
        int &nextIdx=stk.back()[2];
        if(nextIdx<(int)v[u].size())
        {
            const int nxt=v[u][nextIdx++];
            if(nxt==from) continue;   // do not walk back to the parent
            dep[nxt]=dep[u]+1;
            sps[nxt][0]=par[nxt]=u;
            enter(nxt);
            stk.push_back({nxt, u, 0});   // may reallocate; nextIdx is not used after this
        }
        else
        {
            // All children finished: the subtree of u occupies in[u]..iidx.
            out[u]=iidx;
            stk.pop_back();
        }
    }
}
int lca(int a, int b)
{
if(dep[a]<dep[b]) swap(a, b);
for(int i=19;i>=0;i--)
{
if(dep[a]-(1<<i)>=dep[b]) a=sps[a][i];
}
if(a==b) return a;
for(int i=19;i>=0;i--)
{
if(sps[a][i]!=sps[b][i]) a=sps[a][i], b=sps[b][i];
}
return sps[a][0];
}
/*
 * onion ("union"): try to merge colour class c1 into colour class c2.
 *
 * On success:
 *   - every vertex of c1 is recoloured to c2 (C[]) and moved into et[c2];
 *   - np[c1] is folded into np[c2];
 *   - root[c2] becomes the LCA of the two old roots;
 *   - cnt[c2] absorbs cnt[c1];
 *   - c2 is re-keyed in `s` under its new size;
 *   - c1's sets are cleared.
 * If the merge is rejected, nothing is changed here. main() has already
 * removed c1 from `s`, and onion() never re-inserts c1, so a rejected c1 is
 * not revisited.
 *
 * The merge happens only when both of these hold:
 *   (a) `ok`: for some vertex e of c1, the first vertex of c2 after e in
 *       Euler (in[]) order exists and is not root[c2];
 *   (b) c2 is still present in `s` under its current size.
 * NOTE(review): the reasoning behind condition (a) is not evident from this
 * code alone -- confirm against the intended algorithm.
 */
void onion(int c1, int c2)
{
int r2=root[c2];
bool ok=false;
for(int e : et[c1])
{
auto it=et[c2].lower_bound(e); // first c2 vertex whose in[] is not below in[e]
if(it!=et[c2].end() && *it!=r2) ok=true;
}
if(ok && s.find(pii(et[c2].size(), c2))!=s.end())
{
// Remove c2's old (size, colour) key; it is re-inserted with the new size at the end.
s.erase(pii(et[c2].size(), c2));
// Drop both classes' root-parent entries; the merged class gets a fresh one below.
np[c1].erase(par[root[c1]]);
np[c2].erase(par[root[c2]]);
root[c2]=lca(root[c1], root[c2]);
np[c2].insert(par[root[c2]]);
// Recolour c1's vertices; any of them listed as a boundary parent of c2 is now inside c2.
for(int e : et[c1])
{
C[e]=c2;
if(np[c2].find(e)!=np[c2].end()) np[c2].erase(e);
}
// Carry over c1's boundary parents unless they already belong to c2.
for(int e : np[c1])
{
if(et[c2].find(e)==et[c2].end()) np[c2].insert(e);
}
// Move c1's vertices into c2 (c1 is the smaller class when called from main).
for(int e : et[c1]) et[c2].insert(e);
et[c1].clear();
np[c1].clear();
cnt[c2]+=cnt[c1];
s.insert(pii(et[c2].size(), c2));
}
}
// Entry point. Reads a tree with N vertices (N-1 edges a b, 1-indexed) and a
// colour C[i] in 1..K for each vertex. Colour classes are then greedily merged,
// smallest first. The program prints ans-1, where ans is the smallest cnt[]
// (number of original colours combined) of any class found to be closed.
int main()
{
cin.tie(0);
ios::sync_with_stdio(false);
dep[0]=-1; // vertex 0 is the sentinel parent of the root (par[1] stays 0); put it one level above the root
cin >> N >> K;
for(int i=1;i<N;i++) // N-1 undirected edges
{
int a, b;
cin >> a >> b;
v[a].push_back(b);
v[b].push_back(a);
}
for(int i=1;i<=N;i++) cin >> C[i], ct[C[i]].push_back(i);
// A colour occurring on exactly one vertex is trivially connected: zero merges.
for(int i=1;i<=K;i++) if(ct[i].size()==1)
{
cout << 0;
return 0;
}
dfs1(1); // root the tree at vertex 1: Euler indices, depths, parents, lifting table
// root[i] = LCA of all vertices of colour i; np[i] is seeded with that LCA's parent.
// NOTE(review): ct[i][0] assumes every colour 1..K occurs at least once; an absent
// colour would index an empty vector -- confirm against the input constraints.
for(int i=1;i<=K;i++)
{
root[i]=ct[i][0];
for(int ele : ct[i]) root[i]=lca(root[i], ele);
np[i].insert(par[root[i]]);
}
// Pending colours ordered by (class size, colour id): smallest class first.
for(int i=1;i<=K;i++) s.insert(pii(ct[i].size(), i));
for(int i=1;i<=K;i++) for(int ele : ct[i]) et[i].insert(ele);
// np[i] also collects every parent of a colour-i vertex whose colour differs from i.
for(int i=1;i<=K;i++) for(int ele : ct[i]) if(C[par[ele]]!=i) np[i].insert(par[ele]);
for(int i=1;i<=K;i++) cnt[i]=1; // each class starts as a single original colour
/*for(int i=1;i<=K;i++)
{
printf("i=%d: ", i);
for(int ele : np[i]) printf("%d ", ele);
printf("\n");
}*/
// Greedy loop: repeatedly take the pending colour with the fewest vertices.
while(s.size())
{
int c1=s.begin()->se;
s.erase(s.begin());
// np[c1] always holds par[root[c1]] (seeded above, re-inserted by onion), so a
// single entry presumably means no other outside parent: the class is closed.
if(np[c1].size()==1)
{
ans=min(ans, cnt[c1]);
continue;
}
// Otherwise pick the smallest-numbered vertex in np[c1] other than par[root[c1]]
// and try to merge c1 into that vertex's current colour.
auto it=np[c1].begin();
int x=*it;
if(x==par[root[c1]]) x=*(++it);
int c2=C[x];
onion(c1, c2);
/*for(int i=1;i<=K;i++)
{
printf("i=%d: ", i);
for(int ele : np[i]) printf("%d ", ele);
printf("\n");
}*/
}
// ans counts original colours in the best class; presumably merges needed = colours - 1.
// NOTE(review): if no class was ever recorded, ans is still mxN and mxN-1 is printed.
cout << ans-1;
}
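/*
Grader results table captured from the submission page (each row still read "Fetching results..."):
# | Verdict | Execution time | Memory | Grader output |
---|---|---|---|---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|---|---|---|---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|---|---|---|---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|---|---|---|---|
Fetching results... |
*/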