#include <bits/stdc++.h>
using namespace std;
using ll = long long;
using pii = pair<ll, ll>;

// Capacities. The Euler-tour segment tree has Nm = 2^E leaves, so node ids,
// colour ids and Euler-tour positions must all stay below Nm.
const ll E = 20;
const ll Nm = 1LL << E; // 1048576
const ll INF = 1000000000000000000LL; // 1e18

// Smallest colour count of a completed sink component seen so far.
ll ans = INF;

// Min-height segment tree over the Euler tour; entries are {node, height}.
pii stlca[2 * Nm];

// Tree structure (rooted at node 0).
vector<ll> adj[Nm];  // undirected adjacency as read
vector<ll> fadj[Nm]; // children of each node
ll ht[Nm];           // depth of each node
ll radj[Nm];         // parent of each node
ll floc[Nm];         // first Euler-tour position of each node

// Colours.
ll C[Nm];   // colour of each node
ll lcc[Nm]; // lowest common ancestor of all nodes of a given colour

// Heavy-path decomposition and segment-tree leaf ordering.
ll xroot[Nm];          // top node of the heavy chain containing a node
ll spos[Nm];           // leaf position of a node in the segment-tree order
ll ispos[Nm];          // inverse of spos: node stored at a leaf position
ll sts[Nm];            // subtree size
vector<pii> vroot[Nm]; // nodes whose chain top is this node: {height, node}
set<ll> segf[Nm];      // segment-tree node indices a colour points to

// Search flags indexed by element id (colours, then Nm + segment-tree node).
vector<bool> found(3 * Nm, false); // found -> backtrack and fuse until you find this point
vector<bool> prc(3 * Nm, false);   // processed -> immediately terminate entire sequence
//Element-id convention for found/prc and the search stack:
//a colour c is referred to as c itself; a segment-tree node p is referred to as Nm+p
// 2-adic valuation of x: the index of its lowest set bit (v2(12) == 2).
// v2(0) is defined as 64 instead of calling the builtin, because
// __builtin_ctz/__builtin_ctzll are undefined for 0, and the range
// decompositions call v2(l) with l == 0 (position 0 of the tour / leaf order).
// Every positive ll yields at most 62, so a comparison v2(0) < v2(y) with y > 0
// is always false, and such callers peel from the right end as intended.
inline ll v2(ll x) {
    if (x == 0) {
        return 64;
    }
    // 64-bit builtin: the argument is ll, so avoid truncating to unsigned int.
    return __builtin_ctzll(static_cast<unsigned long long>(x));
}
// Pick the {node, height} candidate with the smaller height; on equal heights
// the first argument wins.
inline pii fz(pii p1, pii p2) {
    const bool keepFirst = p1.second <= p2.second;
    return keepFirst ? p1 : p2;
}
// Range-minimum query on stlca over leaf positions [l, r] (inclusive).
// Returns the {node, height} entry of least height, or {-1, INF} when l > r.
// The range is peeled into aligned power-of-two blocks, each time from the end
// whose alignment (v2) is smaller. Blocks are combined with fz in peel order,
// so ties resolve to the earliest-peeled block.
pii gmin(ll l, ll r) {
    const pii none = {-1, INF};
    pii best = none;
    bool any = false;
    while (l <= r) {
        const ll lowL = v2(l);
        const ll lowR = v2(r + 1);
        pii piece;
        if (lowL < lowR) {
            piece = stlca[(l >> lowL) + (1 << (E - lowL))];
            l += 1 << lowL;
        } else {
            piece = stlca[(r >> lowR) + (1 << (E - lowR))];
            r -= 1 << lowR;
        }
        best = any ? fz(best, piece) : piece;
        any = true;
    }
    // Combine with the empty-range sentinel last, as the recursive form did.
    return any ? fz(best, none) : none;
}
// Lowest common ancestor of nodes a and b, via a range-minimum query over the
// Euler tour between their first occurrences.
// -1 is treated as "no node": lca(-1, b) == b and lca(a, -1) == a, which lets
// a running LCA be folded starting from -1.
ll lca(ll a, ll b) {
    if (a == -1) {
        return b;
    }
    if (b == -1 || a == b) {
        return a;
    }
    assert(floc[a]!=INF);
    assert(floc[b]!=INF);
    const ll lo = min(floc[a], floc[b]);
    const ll hi = max(floc[a], floc[b]);
    return gmin(lo, hi).first;
}
// Fill sts[v] with the subtree size of every node v in the subtree of x
// (children taken from fadj) and return sts[x].
// Iterative on purpose: a recursive walk nests once per tree level, which can
// exhaust the call stack on a path-shaped tree with about N levels.
ll getsts(ll x) {
    vector<ll> order; // preorder: every parent appears before its children
    vector<ll> todo{x};
    while (!todo.empty()) {
        const ll v = todo.back();
        todo.pop_back();
        order.push_back(v);
        for (ll y : fadj[v]) {
            todo.push_back(y);
        }
    }
    // Walking the preorder backwards finishes all children before their parent.
    for (auto it = order.rbegin(); it != order.rend(); ++it) {
        ll total = 1;
        for (ll y : fadj[*it]) {
            total += sts[y];
        }
        sts[*it] = total;
    }
    return sts[x];
}
// Heavy-path decomposition of the subtree of x; xroot[x] must already be set.
// For each node v (in preorder, children in fadj order) this appends
// {ht[v], v} to vroot[xroot[v]]. It then assigns xroot for v's children: the
// first child with the largest sts continues v's chain, and every other child
// starts its own chain.
// Iterative to avoid one stack frame per tree level on deep trees. Children
// are pushed in reverse so the visit order matches the recursive preorder
// exactly.
void wdcmp(ll x) {
    vector<ll> todo{x};
    while (!todo.empty()) {
        const ll v = todo.back();
        todo.pop_back();
        vroot[xroot[v]].push_back({ht[v], v});

        ll heavy = -1;
        ll heavySize = -1;
        for (ll y : fadj[v]) {
            if (sts[y] > heavySize) { // strict: ties keep the earliest child
                heavy = y;
                heavySize = sts[y];
            }
        }
        for (ll y : fadj[v]) {
            xroot[y] = (y == heavy) ? xroot[v] : y;
        }
        for (auto it = fadj[v].rbegin(); it != fadj[v].rend(); ++it) {
            todo.push_back(*it);
        }
    }
}
// Split leaf range [l, r] into aligned segment-tree nodes (peeling from the end
// with the smaller v2, as gmin does) and insert each node index into segf[c0].
void wst0(ll l, ll r, ll c0) {
    while (l <= r) {
        const ll lowL = v2(l);
        const ll lowR = v2(r + 1);
        if (lowR <= lowL) {
            segf[c0].insert((r >> lowR) + (1 << (E - lowR)));
            r -= 1 << lowR;
        } else {
            segf[c0].insert((l >> lowL) + (1 << (E - lowL)));
            l += 1 << lowL;
        }
    }
}
// Record in segf[c0] the segment-tree nodes covering the tree path from x up
// to yf, one heavy chain at a time. Does nothing if x is -1 or sits above yf.
// NOTE(review): assumes yf is an ancestor of x, so yf lies on the chain where
// the walk stops.
void wst(ll x, ll yf, ll c0) {
    while (x != -1 && ht[x] >= ht[yf]) {
        const ll top = xroot[x];
        if (ht[top] < ht[yf]) {
            // Chain top is above yf: cover only yf..x and stop.
            wst0(spos[yf], spos[x], c0);
            return;
        }
        wst0(spos[top], spos[x], c0);
        x = radj[top];
    }
}
vector<set<ll>> celemST; //current elements (segtree terms / colors)
vector<set<ll>> celemP; //current elements (points)
vector<set<ll>> cfwd; //current forward degrees
void exec() {
if (cfwd.back().size()==0) {
ans = min(ans,(ll)celemP.back().size());
for (auto A0: celemST) {
for (ll y: A0) {
prc[y]=1;
}
}
for (auto A0: celemP) {
for (ll y: A0) {
prc[y]=1;
}
}
celemST.clear();
celemP.clear();
cfwd.clear();
return;
}
// cout << "current state:\n";
// for (ll x0: celemST.back()) {
// cout << "celemST: "<<x0<<"\n";
// }
// for (ll x0: celemP.back()) {
// cout << "celemP: "<<x0<<"\n";
// }
// for (ll x0: cfwd.back()) {
// cout << "cfwd: "<<x0<<"\n";
// }
ll ynew = *((cfwd.back()).begin()); //cout << "ynew="<<ynew<<"\n";
cfwd.back().erase(ynew);
if (prc[ynew]) {
for (auto A0: celemST) {
for (ll y: A0) {
prc[y]=1;
}
}
for (auto A0: celemP) {
for (ll y: A0) {
prc[y]=1;
}
}
celemST.clear();
celemP.clear();
cfwd.clear();
return;
}
if (found[ynew]) {
while (1) {
if (celemST.back().find(ynew)!=celemST.back().end() || celemP.back().find(ynew)!=celemP.back().end()) {
break;
}
ll T = celemST.size();
for (ll z: celemST[T-1]) {
celemST[T-2].insert(z);
}
for (ll z: celemP[T-1]) {
celemP[T-2].insert(z);
}
for (ll z: cfwd[T-1]) {
cfwd[T-2].insert(z);
}
celemST.pop_back();
celemP.pop_back();
cfwd.pop_back();
//cout << "backtrack\n";
}
} else {
found[ynew]=1;
if (ynew<Nm) {
celemST.push_back((set<ll>){ynew});
celemP.push_back((set<ll>){});
set<ll> cfn;
for (ll z: segf[ynew]) {
//cout << "ynew="<<ynew<<", z="<<z<<"\n";
cfn.insert(z+Nm);
}
cfwd.push_back(cfn);
} else if (ynew<(2*Nm)) {
celemST.push_back((set<ll>){ynew});
celemP.push_back((set<ll>){});
//cout << "terms: "<<(2*ynew-Nm)<<","<<(2*ynew-Nm+1)<<"\n";
cfwd.push_back((set<ll>){2*ynew-Nm,2*ynew-Nm+1});
} else {
celemST.push_back((set<ll>){});
celemP.push_back((set<ll>){ynew});
cfwd.push_back((set<ll>){C[ispos[ynew-2*Nm]]});
}
}
if (cfwd.back().size()==0) {
ll numCol = 0;
for (ll x0: celemST.back()) {
if (x0<Nm) {
numCol++;
}
}
ans = min(ans,numCol);
for (auto A0: celemST) {
for (ll y: A0) {
prc[y]=1;
}
}
for (auto A0: celemP) {
for (ll y: A0) {
prc[y]=1;
}
}
celemST.clear();
celemP.clear();
cfwd.clear();
return;
}
exec();
}
// Reads the tree and colours, builds the LCA / heavy-path / segment-tree
// structures, runs exec() from every point not yet processed, and prints
// (smallest colour count of a sink component) - 1.
// Input: N K, then N-1 edges (1-based endpoints), then N colours (1-based).
int main() {
    ios_base::sync_with_stdio(false); cin.tie(0);
    ll N,K; cin >> N >> K;
    if (K==1) {
        cout << "0\n"; exit(0);
    }
    for (ll i=0;i<Nm;i++) {
        floc[i]=INF;
        lcc[i]=-1;
    }
    for (ll i=0;i<(N-1);i++) {
        ll a,b; cin >> a >> b;
        a--; b--;
        adj[a].push_back(b);
        adj[b].push_back(a);
    }

    // Root the tree at 0 and record an Euler tour of {node, height} without
    // recursion: {x,0} expands x, and {x,1} marks the return to x after a child.
    ht[0]=0;
    radj[0]=-1;
    stack<pii> s0;
    s0.push({0,0});
    vector<pii> ett; //euler tour
    while (!s0.empty()) {
        pii pt = s0.top(); s0.pop();
        ett.push_back({pt.first,ht[pt.first]});
        if (pt.second==0) {
            ll x = pt.first;
            for (ll y: adj[x]) {
                if (y!=radj[x]) {
                    fadj[x].push_back(y);
                    radj[y]=x;
                    ht[y]=ht[x]+1;
                    s0.push({x,1});
                    s0.push({y,0});
                }
            }
        }
    }

    // Min-height segment tree over the tour (leaves at Nm + i) for LCA queries.
    for (ll i=0;i<((ll)ett.size());i++) {
        stlca[i+Nm]=ett[i];
        floc[ett[i].first]=min(floc[ett[i].first],i);
    }
    for (ll p=(Nm-1);p>=1;p--) {
        stlca[p]=fz(stlca[2*p],stlca[2*p+1]);
    }

    // Read colours and fold each colour's nodes into their common LCA.
    for (ll i=0;i<N;i++) {
        cin >> C[i]; C[i]--;
        lcc[C[i]]=lca(lcc[C[i]],i);
    }

    getsts(0);
    xroot[0]=0;
    wdcmp(0); //write the decomposition

    // Assign leaf positions chain by chain, each chain ordered by height.
    ll IC = 0;
    for (ll i=0;i<Nm;i++) {
        sort(vroot[i].begin(),vroot[i].end());
        for (pii p0: vroot[i]) {
            spos[p0.second]=(IC);
            ispos[IC]=p0.second;
            IC++;
        }
    }

    // For each non-root node, record the segment-tree cover of its path up to
    // its colour's LCA.
    for (ll i=1;i<N;i++) {
        wst(i,lcc[C[i]],C[i]); //write segtree
    }

    // Start a search from every point whose leaf term is not yet processed.
    // A point's leaf term is 2*Nm + spos[i] -- the same id exec() resolves via
    // ispos -- not 2*Nm + i. Using 2*Nm + i gave that leaf the wrong outgoing
    // colour and marked the wrong leaves processed.
    for (ll i=0;i<N;i++) {
        const ll leaf = 2*Nm + spos[i];
        if (!prc[leaf]) {
            celemST.clear();
            celemP.clear();
            cfwd.clear();
            celemP.push_back(set<ll>{leaf});
            cfwd.push_back(set<ll>{C[i]});
            celemST.emplace_back();
            found[leaf]=1;
            exec(); //execute
        }
    }
    cout << (ans-1) << "\n";
}
/*
 * Online-judge results table that was pasted along with the source; it is not
 * part of the program. Original text preserved:
 *
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 */