Submission #1123030

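| # | Time | Username | Problem | Language | Result | Execution time | Memory |
|---|------|----------|---------|----------|--------|----------------|--------|
| 1123030 | (not shown) | Math4Life2020 | JOI tour (JOI24_joitour) | C++20 | Compilation error | 0 ms | 0 KiB |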
#include "joitour.h"
#include <bits/stdc++.h>
using namespace std;
using ll = long long; using pii = pair<ll,ll>;
// Nm: bound for the per-vertex global arrays; presumably >= the maximum N
// (2e5) of the task -- confirm against the problem limits. INF is not
// referenced by the active code.
const ll Nm = 2e5+5; const ll INF = 1e18;

ll N; // number of vertices, set by init
ll ans = 0; // running answer returned by num_tours
vector<int> F,U,V; // copies of the labels / edge endpoints passed to init
vector<pii> locs[Nm]; //{index of subtree, index in subtree}: per vertex, its {cdt index, label in that cdt}
ll rlbl[Nm]; // scratch relabelling array used while building the decomposition

// cst ("cdt subtree"): one subtree hanging off a centroid, rooted at r.
// For the current labels Fn it maintains
//   n0  = number of vertices labelled 0,
//   n2  = number of vertices labelled 2,
//   n01 = number of pairs (u, w) with Fn[u]==1, Fn[w]==0 and u a proper
//         ancestor of w (rooted at r),
//   n21 = the same with Fn[w]==2.
// Recomputing these with a full DFS costs O(M) per relabel. Instead the tree
// is laid out in preorder (each subtree is a contiguous interval
// [tin, tout)) and three Fenwick trees support O(log M) relabels:
//   c0 / c2 : point counts of 0- / 2-labelled vertices (subtree sums),
//   a1      : +1 over the interval of every 1-labelled vertex, so a point
//             query at w gives the number of 1-labelled ancestors of w.
// Since n01 = sum over 0-vertices w of a1(w) = sum over 1-vertices u of
// (#0 in subtree(u)), inserting/removing one label changes n01/n21 by a
// single query.
struct cst { //cdt subtree
    // Fenwick tree over positions 0..n-1 (stored 1-based internally).
    struct bit {
        vector<ll> t;
        void reset(ll n) { t.assign(n+1, 0); }
        // Adds d at position i; positions >= n fall off the end and are ignored.
        void add(ll i, ll d) {
            for (i++; i < (ll)t.size(); i += i & -i) t[i] += d;
        }
        // Sum of positions [0, i).
        ll pre(ll i) const {
            ll s = 0;
            for (; i > 0; i -= i & -i) s += t[i];
            return s;
        }
        // Sum of positions [l, h).
        ll sum(ll l, ll h) const { return pre(h) - pre(l); }
    };

    ll M;
    ll n0,n2,n21,n01;
    vector<ll> Fn;
    ll r;
    vector<vector<ll>> fadj; // children of each vertex when rooted at r
    vector<ll> radj; // parent of each vertex (-1 for r and unreached vertices)
    vector<ll> tin, tout; // subtree of x = preorder positions [tin[x], tout[x]); tin=-1 if unreached
    bit c0, c2, a1;

    // Inserts (sg=+1) or removes (sg=-1) the current label Fn[x], updating
    // the Fenwick trees and the four counts. Labels other than 0/1/2 and
    // vertices not reached from r contribute nothing.
    void put(ll x, ll sg) {
        const ll lo = tin[x], hi = tout[x];
        if (lo < 0) return;
        if (Fn[x]==0) {
            n0 += sg;
            n01 += sg*a1.pre(lo+1); // 1-labelled ancestors of x
            c0.add(lo, sg);
        } else if (Fn[x]==2) {
            n2 += sg;
            n21 += sg*a1.pre(lo+1);
            c2.add(lo, sg);
        } else if (Fn[x]==1) {
            n01 += sg*c0.sum(lo, hi); // 0-labelled descendants of x
            n21 += sg*c2.sum(lo, hi);
            a1.add(lo, sg);
            a1.add(hi, -sg);
        }
    }

    // Rebuilds all counts from Fn from scratch in O(M log M).
    void calc() {
        n0=0; n2=0; n21=0; n01=0;
        c0.reset(M); c2.reset(M); a1.reset(M);
        for (ll x=0;x<M;x++) put(x, +1);
    }

    // r0: root; M0: vertex count; adj: undirected tree adjacency on 0..M0-1;
    // f0: initial labels.
    cst(ll r0, ll M0, const vector<vector<ll>>& adj, const vector<ll>& f0) {
        n0=0; n2=0; n21=0; n01=0;
        r=r0; M=M0;
        Fn=f0;
        if (M<=0) return;
        fadj.assign(M, vector<ll>());
        radj.assign(M, -1);
        tin.assign(M, -1);
        tout.assign(M, -1);

        // BFS from r orients the tree; `order` doubles as the BFS queue.
        vector<ll> order;
        order.reserve(M);
        vector<char> seen(M, 0);
        order.push_back(r);
        seen[r] = 1;
        for (size_t h=0; h<order.size(); h++) {
            const ll x = order[h];
            for (ll y: adj[x]) {
                if (!seen[y]) {
                    seen[y] = 1;
                    radj[y] = x;
                    fadj[x].push_back(y);
                    order.push_back(y);
                }
            }
        }

        // Subtree sizes bottom-up (a child always follows its parent in BFS order)...
        vector<ll> sub(M, 1);
        for (size_t k=order.size(); k-- > 1; ) sub[radj[order[k]]] += sub[order[k]];
        // ...then give each child a contiguous block right after its parent.
        tin[r] = 0;
        for (ll x: order) {
            tout[x] = tin[x] + sub[x];
            ll nxt = tin[x] + 1;
            for (ll y: fadj[x]) {
                tin[y] = nxt;
                nxt += sub[y];
            }
        }
        calc();
    }

    // Relabels vertex x to v in O(log M).
    void upd(ll x, ll v) {
        put(x, -1);
        Fn[x]=v;
        put(x, +1);
    }
};

struct cdt { //centroid decomp tree
    ll M; //size
    vector<vector<ll>> fadj;
    vector<ll> Fn; //new F
    vector<pii> strl; //subtree locations: {index of st, index in st}
    vector<cst*> v1;
    ll s21=0, s01=0, s0=0, s2=0, s210=0, s012=0, s02=0;
    cdt(ll M1, vector<vector<ll>> fadj1, vector<ll> Fn1) { //fadj is really just adj oops
        M=M1; fadj=fadj1; Fn=Fn1;
        // for (ll x: Fn) {
        //     cout << "fn term = "<<x<<"\n";
        // }
        for (ll m=0;m<M;m++) {
            strl.push_back((pii){0,0});
        }
        // cout << "fadj: \n";
        // for (ll x=0;x<fadj1.size();x++) {
        //     for (ll y: fadj1[x]) {
        //         cout << "x,y: "<<x<<","<<y<<"\n";
        //     }
        // }
        //root is 0
        ll rcnt = 0;
        for (ll x: fadj[0]) {
            //unordered_map<ll,ll> rlbl; //relabel
            vector<vector<ll>> nadj;
            vector<ll> fnew;
            ll Mn = 0;
            queue<pii> q0;
            q0.push({x,-1});
            //cout << "x="<<x<<"\n";
            while (!q0.empty()) {
                pii p0 = q0.front(); q0.pop();
                ll z = p0.first; ll pz = p0.second;
                if (z==0) {
                    continue;
                }
                //if (rlbl.find(z)==rlbl.end()) {
                    rlbl[z]=Mn++;
                    //cout << "defining z="<<z<<" as "<<rlbl[z]<<"\n";
                    nadj.push_back((vector<ll>){});
                    fnew.push_back(Fn[z]);
                    strl[z]={rcnt,rlbl[z]};
                    //cout << "relabel: z="<<z<<"->"<<rlbl[z]<<"\n";
                    //locs[z].push_back({dind,rlbl[z]});
               // }
                if (pz != -1) {
                    //cout << "z,pz="<<z<<","<<pz<<"\n";
                    nadj[rlbl[z]].push_back(rlbl[pz]);
                    nadj[rlbl[pz]].push_back(rlbl[z]);
                }
                for (ll nz: fadj[z]) {
                    if (nz != pz && nz != 0) {
                        q0.push({nz,z});
                    }
                }
            }
            v1.push_back(new cst(0LL,Mn,nadj,fnew));
            rcnt++;
        }
        for (ll r=0;r<rcnt;r++) {
            s21 += (v1[r]->n21);
            s01 += (v1[r]->n01);
            s0 += (v1[r]->n0);
            s2 += v1[r]->n2;
            s210 += (v1[r]->n21)*(v1[r]->n0);
            s012 += (v1[r]->n01)*(v1[r]->n2);
            s02 += (v1[r]->n0)*(v1[r]->n2);
        }
        ans += (s21*s0-s210+s01*s2-s012);
        if (Fn[0]==0) {
            ans += s21;
        } else if (Fn[0]==1) {
            ans += (s0*s2-s02);
        } else {
            ans += s01;
        }
    }
    void upd(ll x, ll vf) {
        // cout << "updating where M="<<M<<"\n";
        // cout << "initial ans = "<<ans<<"\n";
        if (x==0) {
            ll v0 = Fn[0];
            if (v0==0) {
                ans -= s21;
            } else if (v0==1) {
                ans -= (s0*s2-s02);
            } else {
                assert(v0==2);
                ans -= s01;
            }
            if (vf==0) {
                ans += s21;
            } else if (vf==1) {
                ans += (s0*s2-s02);
            } else {
                assert(vf==2);
                ans += s01;
            }
        } else {
            ll v0 = Fn[0];
            if (v0==0) {
                ans -= s21;
            } else if (v0==1) {
                ans -= (s0*s2-s02);
            } else {
                assert(v0==2);
                ans -= s01;
            }
            ans -= (s21*s0-s210+s01*s2-s012);

            ll i = strl[x].first;
            s21 -= (v1[i]->n21);
            s01 -= (v1[i]->n01);
            s2 -= (v1[i]->n2);
            s0 -= (v1[i]->n0);
            s210 -= (v1[i]->n21)*(v1[i]->n0);
            s012 -= (v1[i]->n01)*(v1[i]->n2);
            s02 -= (v1[i]->n0)*(v1[i]->n2);

            (*v1[i]).upd(strl[x].second,vf);

            i = strl[x].first;
            s21 += (v1[i]->n21);
            s01 += (v1[i]->n01);
            s2 += (v1[i]->n2);
            s0 += (v1[i]->n0);
            s210 += (v1[i]->n21)*(v1[i]->n0);
            s012 += (v1[i]->n01)*(v1[i]->n2);
            s02 += (v1[i]->n0)*(v1[i]->n2);

            v0 = Fn[0];
            if (v0==0) {
                ans += s21;
            } else if (v0==1) {
                ans += (s0*s2-s02);
            } else {
                assert(v0==2);
                ans += s01;
            }
            ans += (s21*s0-s210+s01*s2-s012);
        }
        Fn[x]=vf;
        //cout << "final ans = "<<ans<<"\n";
    }
};

// Adjacency lists of the input tree, indexed by vertex id (filled by init).
vector<ll> adj[Nm];
// found[v] is set once v has been used as a centroid; getsz, getctr and
// bldDcmp then skip v.
bool found[Nm];
// Subtree sizes written by getsz and read by getctr.
ll sz[Nm]; 
// Not referenced by the active code (only the commented-out harness resets it).
ll rev[Nm];
// One cdt per centroid, in the order bldDcmp creates them; the first element
// of each locs[v] entry indexes this vector.
vector<cdt*> cdtr;

// Sizes the live component containing x, rooted at x: fills sz[v] for every
// reached vertex v and returns sz[x]. pr (if not -1) is a neighbour of x that
// is excluded, and vertices with found[] set are skipped.
// Iterative (BFS order, then accumulate in reverse) because a recursive DFS
// reaches depth ~N on a path-shaped tree and can overflow the call stack.
ll getsz(ll x, ll pr = -1) {
    vector<pii> ord; // {vertex, parent} in BFS order
    ord.push_back({x, pr});
    for (size_t i = 0; i < ord.size(); i++) {
        const ll v = ord[i].first, p = ord[i].second;
        sz[v] = 1;
        for (ll y: adj[v]) {
            if (y != p && !found[y]) {
                ord.push_back({y, v});
            }
        }
    }
    // Children follow their parent in BFS order, so walking backwards
    // finalises every child before it is added to its parent.
    for (size_t i = ord.size(); i-- > 1; ) {
        sz[ord[i].second] += sz[ord[i].first];
    }
    return sz[x];
}

// Starting from x (with sz[] rooted at x, as filled by getsz), repeatedly
// steps into the first non-removed child whose subtree holds more than half
// of the sz0 vertices. Returns the vertex where no such child exists.
ll getctr(ll x, ll sz0, ll pr=-1) {
    bool moved = true;
    while (moved) {
        moved = false;
        for (ll nb: adj[x]) {
            if (nb == pr || found[nb]) continue;
            if (2*sz[nb] <= sz0) continue;
            pr = x;
            x = nb;
            moved = true;
            break;
        }
    }
    return x;
}

ll dind = 0; //index in cdtr of the next cdt to be created

// Builds the centroid decomposition of the live component containing x.
// It finds that component's centroid y and BFS-relabels the component with
// y as 0, recording {dind, label} in locs[z] for each vertex z. It then
// creates the cdt for it, marks y as found, and recurses into each
// remaining neighbour component of y.
void bldDcmp(ll x=0) { // x: any vertex of the component to decompose
    ll sz0 = getsz(x);
    ll y = getctr(x,sz0);
    //cout << "centroid at y="<<y<<"\n";
    //unordered_map<ll,ll> rlbl; //relabel
    vector<vector<ll>> nadj; //new adjacency
    vector<ll> fnew;
    ll M = 0;
    queue<pii> q0;
    q0.push({y,-1});
    while (!q0.empty()) {
        pii p0 = q0.front(); q0.pop();
        ll z = p0.first; ll pz = p0.second;
        //if (rlbl.find(z)==rlbl.end()) {
            // BFS order label; the centroid y gets label 0.
            rlbl[z]=M++;
            nadj.push_back((vector<ll>){});
            fnew.push_back(F[z]);
            // dind is this centroid's future index in cdtr (pushed below).
            locs[z].push_back({dind,rlbl[z]});
        //}
        if (pz != -1) {
            nadj[rlbl[z]].push_back(rlbl[pz]);
            nadj[rlbl[pz]].push_back(rlbl[z]);
        }
        for (ll zn: adj[z]) {
            if (!found[zn] && zn != pz) {
                q0.push({zn,z});
            }
        }
    }
    cdtr.push_back(new cdt(M,nadj,fnew));
    found[y]=1;
    dind++;
    // Each unfound neighbour of y lies in a separate remaining component.
    for (ll z: adj[y]) {
        if (!found[z]) {
            bldDcmp(z);
        }
    }
}

void init(int N1, vector<int> F1, vector<int> U1, vector<int> V1, int Q) {
    N=N1;
    F=F1;
    U=U1;
    V=V1;
    for (ll i=0;i<(N-1);i++) {
        adj[U[i]].push_back(V[i]);
        adj[V[i]].push_back(U[i]);
    }
    bldDcmp(); //build centroid decomposition
}

void change(int x, int y) {
    for (pii p0: locs[x]) {
        //cout << "update at "<<p0.first<<","<<p0.second<<"\n";
        (*cdtr[p0.first]).upd(p0.second,y);
    }
}

// Returns the running answer maintained by init and change.
ll num_tours() {
    const ll total = ans;
    return total;
}

// Local test driver only. The grader links its own main() (stub.cpp), so an
// unconditional definition here caused the "multiple definition of `main'"
// link error; compile with -DJOITOUR_LOCAL_MAIN to use this reader.
// Reads N, the N labels, N-1 edges, then Q pairs "x y"; prints num_tours()
// after init and after every change.
#ifdef JOITOUR_LOCAL_MAIN
int main() {
    //ios_base::sync_with_stdio(false); cin.tie(0);
    ll N0; cin >> N0;
    vector<int> F1;
    for (ll x=0;x<N0;x++) {
        ll y; cin >> y; F1.push_back(y);
    }
    vector<int> U1,V1;
    for (ll i=0;i<(N0-1);i++) {
        ll x,y; cin >> x >> y;
        U1.push_back(x);
        V1.push_back(y);
    }
    init((int)N0,F1,U1,V1,-1);
    ll Q; cin >> Q;
    cout << num_tours() <<"\n";
    for (ll q=0;q<Q;q++) {
        ll x,y; cin >> x >> y;
        change(x,y);
        cout << num_tours() <<"\n";
    }
}
#endif

/*void memclear(ll Nt) {
    ans = 0;
    F.clear();
    U.clear();
    V.clear();
    cdtr.clear();
    dind = 0;
    for (ll i=0;i<Nt;i++) {
        locs[i].clear();
        adj[i].clear();
        found[i]=0;
        sz[i]=0;
        rev[i]=-1;
    }
}
ll count(vector<int> vqry, ll N1) {
    ll cnt0=0;
    for (ll i=0;i<N1;i++) {
        for (ll j=(i+1);j<N1;j++) {
            for (ll k=(j+1);k<N1;k++) {
                if (vqry[i]==0 && vqry[j]==1 && vqry[k]==2) {
                    cnt0++;
                }
                if (vqry[i]==2 && vqry[j]==1 && vqry[k]==0) {
                    cnt0++;
                }
            }
        }
    }
    return cnt0;
}

int main() {
    mt19937 gen((ll) new char);
    ll N1 = 4;
    vector<int> U1,V1;
    for (ll i=0;i<(N1-1);i++) {
        U1.push_back(i);
        V1.push_back(i+1);
    }
    for (ll T=0;T<(20000);T++) {
        memclear(2*N1);
        vector<int> vqry,vqry0;
        for (ll i=0;i<N1;i++) {
            vqry.push_back(gen()%3);
        }
        vqry0=vqry;
        init((int)N1,vqry,U1,V1,-1);
        vector<pii> vupd;
        for (ll T1=0;T1<1;T1++) {
            ll x = gen()%N1; ll v = gen()%3;
            if (vqry[x]==v){
                continue;
            }
            //cout << "x,vqry[x],v="<<x<<","<<vqry[x]<<","<<v<<"\n";
            change(x,v);
            vqry[x]=v;
            vupd.push_back({x,v});
            if (num_tours()!=count(vqry,N1)) {
                cout << "T="<<T<<"\n";
                cout << "init vqry: \n";
                for (ll x: vqry0) {
                    cout << x << " ";
                }
                cout << "\n";
                cout << "updates: \n";
                for (pii p0: vupd) {
                    cout << p0.first << " "<<p0.second<<"\n";
                }
                cout << "expected, true="<<num_tours()<<", "<<count(vqry,N1)<<"\n";
                exit(0);
            }
        }
    }
}*/

Compilation message (stderr)

/usr/bin/ld: /tmp/ccdExyA0.o: in function `main':
stub.cpp:(.text.startup+0x0): multiple definition of `main'; /tmp/ccm8Annf.o:joitour.cpp:(.text.startup+0x0): first defined here
collect2: error: ld returned 1 exit status