/*
 * Submission #1160757
 *
 * # | Time | Username | Problem | Language | Result | Execution time | Memory
 * 1160757 | (not captured) | Math4Life2020 | Capital City (JOI20_capital_city) | C++20 | 100 / 100 | 1519 ms | 292352 KiB
 */
#include <bits/stdc++.h>
using namespace std;

using ll = long long; using pii = pair<ll,ll>;

// Nm = 2^20 is the segment-tree leaf count and the size of every per-vertex /
// per-color array; E = log2(Nm) is used to turn aligned blocks into node indices.
// INF doubles as "+infinity" and as the "not set yet" marker (e.g. floc).
const ll Nm = 1048576; const ll E = 20; const ll INF = 1e18;
ll ans = INF; // smallest color count of a closed component found so far; main prints ans-1

pii stlca[2*Nm]; //segtree for LCA: leaf Nm+i holds Euler-tour entry i as {vertex, depth}; internal nodes keep the smaller depth (see fz)
vector<ll> adj[Nm]; // undirected adjacency lists as read from input
vector<ll> fadj[Nm]; // children of each vertex in the tree rooted at vertex 0
ll ht[Nm]; // depth of each vertex (root 0 has depth 0)
ll radj[Nm]; // parent of each vertex (-1 for the root)
ll floc[Nm]; // first index of each vertex in the Euler tour (INF until seen)
ll C[Nm]; // color of each vertex, 0-indexed
ll lcc[Nm]; //lowest common ancestor of all vertices of a given color c (-1 while none seen)
ll xroot[Nm]; //top vertex of the heavy chain that contains a vertex
ll spos[Nm]; //position of a vertex in chain-by-chain order = its segment-tree leaf index
ll ispos[Nm]; // inverse of spos: vertex stored at a given position
ll sts[Nm]; //subtree size
vector<pii> vroot[Nm]; //per chain top: {depth, vertex} for every vertex of that heavy chain
set<ll> segf[Nm]; //per color: segment-tree node indices covering its paths (exec() uses them as element Nm+index)
vector<bool> found(3*Nm,0); //per element id: already pushed onto the exec() stack -> reaching it again means backtrack and fuse
vector<bool> prc(3*Nm,0); //per element id: belongs to a finished run -> reaching it immediately terminates the current run

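// Element-id convention (used by found[], prc[], celemST, cfwd and exec()):
//   - color c (0 <= c < K) is element c;
//   - segment-tree node p (same heap layout as stlca/segf, 1 <= p < 2*Nm) is
//     element Nm+p, so leaf nodes (p >= Nm) are exactly the elements >= 2*Nm.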

// Number of trailing zero bits of x, i.e. log2 of the largest power of two
// dividing x.
// x == 0 is divisible by every power of two, so it returns 64 (the bit width of
// ll). __builtin_ctz(0) would be undefined behaviour, and gmin()/wst0() call
// v2(l) with l == 0 whenever a range starts at position 0. A huge result there
// makes them take the block ending at r, which is always valid when l == 0.
// __builtin_ctzll is used so the 64-bit argument is not truncated to 32 bits.
inline ll v2(ll x) {
    if (x == 0) {
        return 64;
    }
    return __builtin_ctzll(static_cast<unsigned long long>(x));
}

// Combines two segment-tree entries {vertex, depth}: keeps the one with the
// smaller .second. On a tie the first argument is kept.
inline pii fz(pii p1, pii p2) {
    return (p1.second <= p2.second) ? p1 : p2;
}

// Range-minimum query over stlca for Euler-tour positions [l, r], inclusive.
// The range is peeled into aligned power-of-two blocks, taking from whichever
// end has the larger alignment (v2(l) vs. v2(r+1)); each block is one stlca
// node. Returns the entry with the smallest .second; among equal minima the
// earliest-taken block wins. An empty range yields {-1, INF}.
pii gmin(ll l, ll r) {
    const pii none = {-1, INF};
    pii best = none;
    bool tookBlock = false;
    while (l <= r) {
        const ll alignL = v2(l);
        const ll alignR = v2(r + 1);
        pii blk;
        if (alignL < alignR) {
            blk = stlca[(l >> alignL) + (1 << (E - alignL))];
            l += (1 << alignL);
        } else {
            blk = stlca[(r >> alignR) + (1 << (E - alignR))];
            r -= (1 << alignR);
        }
        if (!tookBlock || blk.second < best.second) {
            best = blk;
            tookBlock = true;
        }
    }
    // The sentinel only wins when it is strictly smaller (same tie rule as fz).
    if (!tookBlock || none.second < best.second) {
        return none;
    }
    return best;
}

// Lowest common ancestor of vertices a and b. It takes the minimum-depth
// Euler-tour entry between their first occurrences.
// -1 means "no vertex": lca(-1, v) == v, which lets main() fold a color's
// vertices starting from lcc[c] == -1.
ll lca(ll a, ll b) {
    if (a == -1) {
        return b;
    }
    if (b == -1 || a == b) {
        return a;
    }
    assert(floc[a] != INF);
    assert(floc[b] != INF);
    const ll lo = min(floc[a], floc[b]);
    const ll hi = max(floc[a], floc[b]);
    return gmin(lo, hi).first;
}

// Fills sts[v] (subtree size, counting v itself) for every vertex v in the
// subtree of x, following the child lists in fadj. Returns sts[x].
// Uses an explicit stack instead of recursion. Recursion depth would equal the
// tree height (up to N on a path), which can overflow the call stack.
ll getsts(ll x) {
    vector<ll> order;          // vertices in pre-order: parents before children
    vector<ll> pending{x};
    while (!pending.empty()) {
        const ll v = pending.back();
        pending.pop_back();
        order.push_back(v);
        for (ll y : fadj[v]) {
            pending.push_back(y);
        }
    }
    // Reverse pre-order: every child is finished before its parent.
    for (auto it = order.rbegin(); it != order.rend(); ++it) {
        const ll v = *it;
        ll val = 1;
        for (ll y : fadj[v]) {
            val += sts[y];
        }
        sts[v] = val;
    }
    return sts[x];
}

// Heavy-light decomposition of the subtree of x.
// - xroot[x] must already be set by the caller.
// - For every vertex it records {depth, vertex} in vroot[top of its chain].
// - It sets xroot for all descendants. The heavy child (first child with the
//   largest sts) continues its parent's chain; every other child starts a new
//   chain.
// Requires sts[] from getsts(). Vertices are visited in the same pre-order as a
// recursive DFS over fadj. An explicit stack is used because recursion depth
// would equal the tree height and can overflow the call stack.
void wdcmp(ll x) {
    vector<ll> pending{x};
    while (!pending.empty()) {
        const ll v = pending.back();
        pending.pop_back();
        vroot[xroot[v]].push_back({ht[v], v});
        if (fadj[v].empty()) {
            continue;
        }
        ll heavy = -1;
        ll heavySize = -1;
        for (ll y : fadj[v]) {
            if (sts[y] > heavySize) {
                heavy = y;
                heavySize = sts[y];
            }
        }
        for (ll y : fadj[v]) {
            xroot[y] = (y == heavy) ? xroot[v] : y;
        }
        // Push in reverse so children are popped in fadj order (recursive visit order).
        for (auto it = fadj[v].rbegin(); it != fadj[v].rend(); ++it) {
            pending.push_back(*it);
        }
    }
}

// Adds to segf[c0] the indices of the segment-tree nodes that exactly cover
// leaf positions [l, r], inclusive. Nodes are taken as aligned power-of-two
// blocks from whichever end has the larger alignment. Does nothing if l > r.
void wst0(ll l, ll r, ll c0) {
    while (l <= r) {
        const ll alignL = v2(l);
        const ll alignR = v2(r + 1);
        if (alignL < alignR) {
            segf[c0].insert((l >> alignL) + (1 << (E - alignL)));
            l += (1 << alignL);
        } else {
            segf[c0].insert((r >> alignR) + (1 << (E - alignR)));
            r -= (1 << alignR);
        }
    }
}

// Adds to segf[c0] segment-tree nodes covering the tree path from x up to its
// ancestor yf, both ends included. It climbs one heavy chain at a time.
// Chain positions (spos) are contiguous and ordered by depth, so each chain
// piece is one position range.
void wst(ll x, ll yf, ll c0) {
    ll cur = x;
    while (cur != -1 && ht[cur] >= ht[yf]) {
        const ll top = xroot[cur];
        if (ht[top] < ht[yf]) {
            // yf lies on cur's chain: cover spos[yf]..spos[cur] and stop.
            wst0(spos[yf], spos[cur], c0);
            return;
        }
        // The whole chain piece top..cur is on the path; continue above the chain.
        wst0(spos[top], spos[cur], c0);
        cur = radj[top];
    }
}

// Stack state of the current exec() run. Level k holds a group of element ids
// (colors / segtree terms); levels are fused when an edge leads back into the stack.
vector<set<ll>> celemST; //current elements (segtree terms / colors) of each stack level
vector<set<ll>> cfwd; //per stack level: outgoing edge targets (element ids) not explored yet

void exec() {
    if (cfwd.back().size()==0) {
        ll numCol = 0;
        for (ll x0: celemST.back()) {
            if (x0<Nm) {
                numCol++;
            }
        }
        ans = min(ans,numCol);
        for (auto A0: celemST) {
            for (ll y: A0) {
                prc[y]=1;
            }
        }
        celemST.clear();
        cfwd.clear();
        return;
    }
    // cout << "current state:\n";
    // for (ll x0: celemST.back()) {
    //     cout << "celemST: "<<x0<<"\n";
    // }
    // for (ll x0: celemP.back()) {
    //     cout << "celemP: "<<x0<<"\n";
    // }
    // for (ll x0: cfwd.back()) {
    //     cout << "cfwd: "<<x0<<"\n";
    // }
    ll ynew = *((cfwd.back()).begin()); //cout << "ynew="<<ynew<<"\n";
    cfwd.back().erase(ynew);
    if (prc[ynew]) {
        for (auto A0: celemST) {
            for (ll y: A0) {
                prc[y]=1;
            }
        }
        celemST.clear();
        cfwd.clear();
        return;
    }
    if (found[ynew]) {
        while (1) {
            if (celemST.back().find(ynew)!=celemST.back().end()) {
                break;
            }
            ll T = celemST.size();
            for (ll z: celemST[T-1]) {
                celemST[T-2].insert(z);
            }
            for (ll z: cfwd[T-1]) {
                cfwd[T-2].insert(z);
            }
            celemST.pop_back();
            cfwd.pop_back();
            //cout << "backtrack\n";
        }
    } else {
        found[ynew]=1;
        if (ynew<Nm) {
            celemST.push_back((set<ll>){ynew});
            set<ll> cfn;
            for (ll z: segf[ynew]) {
                //cout << "ynew="<<ynew<<", z="<<z<<"\n";
                cfn.insert(z+Nm);
            }
            cfwd.push_back(cfn);
        } else if (ynew<(2*Nm)) {
            celemST.push_back((set<ll>){ynew});
            //cout << "terms: "<<(2*ynew-Nm)<<","<<(2*ynew-Nm+1)<<"\n";
            cfwd.push_back((set<ll>){2*ynew-Nm,2*ynew-Nm+1});
        } else {
            celemST.push_back((set<ll>){ynew});
            cfwd.push_back((set<ll>){C[ispos[ynew-2*Nm]]});
        }
    }
    if (cfwd.back().size()==0) {
        ll numCol = 0;
        for (ll x0: celemST.back()) {
            if (x0<Nm) {
                numCol++;
            }
        }
        ans = min(ans,numCol);
        for (auto A0: celemST) {
            for (ll y: A0) {
                prc[y]=1;
            }
        }
        celemST.clear();
        cfwd.clear();
        return;
    }
    exec();
}

int main() {
	ios_base::sync_with_stdio(false); cin.tie(0);
    ll N,K; cin >> N >> K;
    if (K==1) {
        cout << "0\n"; exit(0);
    }
    for (ll i=0;i<Nm;i++) {
        floc[i]=INF;
        lcc[i]=-1;
    }
    for (ll i=0;i<(N-1);i++) {
        ll a,b; cin >> a >> b;
        a--; b--;
        adj[a].push_back(b);
        adj[b].push_back(a);
    }
    ht[0]=0;
    radj[0]=-1;
    stack<pii> s0;
    s0.push({0,0});
    vector<pii> ett; //euler tour
    while (!s0.empty()) {
        pii pt = s0.top(); s0.pop();
        ett.push_back({pt.first,ht[pt.first]});
        if (pt.second==0) {
            ll x = pt.first;
            for (ll y: adj[x]) {
                if (y!=radj[x]) {
                    fadj[x].push_back(y);
                    radj[y]=x;
                    ht[y]=ht[x]+1;
                    s0.push({x,1});
                    s0.push({y,0});
                }
            }
        }
    }
    for (ll i=0;i<((ll)ett.size());i++) {
        stlca[i+Nm]=ett[i];
        //cout << "ett: "<<ett[i].first<<","<<ett[i].second<<"\n";
        floc[ett[i].first]=min(floc[ett[i].first],i);
    }
    for (ll p=(Nm-1);p>=1;p--) {
        stlca[p]=fz(stlca[2*p],stlca[2*p+1]);
    }
    for (ll i=0;i<N;i++) {
        cin >> C[i]; C[i]--;
        //cout << "init = "<<lcc[C[i]]<<","<<i<<"; lca=";
        lcc[C[i]]=lca(lcc[C[i]],i); //cout << lcc[C[i]] <<"\n";
    }
    for (ll i=0;i<K;i++) {
        //cout << "lcc[i="<<i<<"]="<<lcc[i]<<"\n";
    }
    getsts(0);
    xroot[0]=0;
    wdcmp(0); //write the decomposition 
    ll IC = 0;
    for (ll i=0;i<Nm;i++) {
        sort(vroot[i].begin(),vroot[i].end());
        for (pii p0: vroot[i]) {
            spos[p0.second]=(IC);
            ispos[IC]=p0.second;
            IC++;
        }
    }
    for (ll i=0;i<N;i++) {
        //cout << "spos[i="<<i<<"]="<<spos[i]<<"\n";
    }
    for (ll i=1;i<N;i++) {
        wst(i,lcc[C[i]],C[i]); //write segtree
    }
    for (ll k=0;k<K;k++) {
        //cout << "k="<<k<<", segtree terms: \n";
        for (ll z: segf[k]) {
           // cout << "z="<<z<<"\n";
        }
    }
    for (ll i=0;i<N;i++) {
        if (!prc[2*Nm+i]) {
            //cout << "process i="<<i<<"\n";
            celemST.clear();
            cfwd.clear();
            set<ll> sprim; sprim.insert(i+2*Nm);
            //cout << "sprim elem: "<<(i+2*Nm)<<"\n";
            set<ll> cprim; cprim.insert(C[ispos[i]]);
            celemST.push_back(sprim);
            cfwd.push_back(cprim);
            found[i+2*Nm]=1;
            exec(); //execute
        }
    }
    cout << (ans-1) << "\n";
}
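/*
 * Grader result tables captured with the page (results had not loaded yet):
 * # | Verdict | Execution time | Memory | Grader output
 * Fetching results...
 * # | Verdict | Execution time | Memory | Grader output
 * Fetching results...
 * # | Verdict | Execution time | Memory | Grader output
 * Fetching results...
 * # | Verdict | Execution time | Memory | Grader output
 * Fetching results...
 */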