Submission #1123712

# | Time | Username | Problem | Language | Result | Execution time | Memory
1123712 | Math4Life2020 | JOI tour (JOI24_joitour) | C++20
64 / 100
3114 ms | 777432 KiB
#pragma GCC optimize("Ofast,unroll-loops")
// The target pragma names x86-only ISA extensions; other targets reject it.
#if defined(__x86_64__) || defined(__i386__)
#pragma GCC target("avx2,popcnt,lzcnt,abm,bmi,bmi2,fma")
#endif
// The grader provides joitour.h; the guard lets this file also compile
// standalone (e.g. for local testing) when the header is not available.
#if defined(__has_include)
#if __has_include("joitour.h")
#include "joitour.h"
#endif
#endif
#include <cassert>
#include <queue>
#include <stack>
#include <utility>
#include <vector>
using namespace std;

// Despite the name, `ll` is a 32-bit int (keeps the large tables small);
// every quantity that can exceed 2^31 is declared `long long` explicitly.
using ll = int;
using pii = pair<ll, ll>;

constexpr ll Nm = 200005;  // maximum number of vertices (+ slack)
ll N;
long long ans = 0;     // current answer, returned by num_tours()
vector<int> F, U, V;   // input copies (F is only read while building)
vector<pii> locs[Nm];  // per vertex: {index into cdtr, local label in that cdt}
ll rlbl[Nm];           // scratch buffer used while relabelling a component

// One global segment tree whose leaves are "slots" handed out by `allc`;
// every vertex of every centroid subtree (cst) owns one slot.
//   na0[p] / na2[p] : number of active type-0 / type-2 leaves under node p
//   pd0[p] / pd2[p] : additive lazy tags; once pushed down, a leaf's tag is
//                     the number of active type-1 subtree ranges covering it.
// Leaf `slot` lives at index Sm + slot; the ancestor at height e is
// (slot >> e) + 2^(E - e). Only the lowest H levels are ever tagged: every
// range given to wrtI lies inside one centroid subtree (<= Nm / 2 vertices),
// so its dyadic blocks have size <= 2^H.
constexpr ll E = 22;
constexpr ll Sm = 1 << E;  // 4194304 leaves
constexpr ll H = 16;
static_assert((1LL << (H + 1)) > Nm / 2, "range blocks must fit in H levels");
int na0[2 * Sm];
int pd0[2 * Sm];
int na2[2 * Sm];
int pd2[2 * Sm];
ll allc = 0;  // next free leaf slot

// Number of trailing zero bits of x. __builtin_ctz(0) is undefined, so 0 maps
// to 32: a range starting at slot 0 is then always peeled from its right end
// in wrtI (the value a hardware TZCNT would have produced).
inline ll v2(ll x) {
    return x == 0 ? 32 : __builtin_ctz(static_cast<unsigned>(x));
}

// Shared point update for wrt0/wrt2: flushes the lazy tags of the H lowest
// ancestors of leaf x, adds the delta v to the active count of x and of those
// ancestors, and returns v times the fully pushed tag of the leaf.
inline ll wrtLeaf(int* na, int* pd, ll x, ll v) {
    for (ll e = H; e > 0; e--) {
        const ll p = (x >> e) + (1 << (E - e));
        pd[2 * p] += pd[p];
        pd[2 * p + 1] += pd[p];
        pd[p] = 0;
    }
    na[x + Sm] += v;
    for (ll e = 1; e <= H; e++) {
        na[(x >> e) + (1 << (E - e))] += v;
    }
    return v * pd[x + Sm];
}

// Activates (v = +1) / deactivates (v = -1) a type-0 vertex at slot x.
// Returns the resulting change of the (type-1, type-0 descendant) pair count.
ll wrt0(ll x, ll v) { return wrtLeaf(na0, pd0, x, v); }

// Same as wrt0 for type-2 vertices.
ll wrt2(ll x, ll v) { return wrtLeaf(na2, pd2, x, v); }

// Adds v to the lazy tags over the slot range [x, y] (the subtree of a type-1
// vertex), using the canonical dyadic decomposition. Returns
// {v * #active type-0 slots in range, v * #active type-2 slots in range}.
pii wrtI(ll x, ll y, ll v) {
    ll with0 = 0;
    ll with2 = 0;
    while (x <= y) {
        const ll vx = v2(x);
        const ll vy = v2(y + 1);
        ll node;
        if (vx < vy) {  // peel the aligned block starting at x
            node = (x >> vx) + (1 << (E - vx));
            x += 1 << vx;
        } else {        // peel the aligned block ending at y
            node = (y >> vy) + (1 << (E - vy));
            y -= 1 << vy;
        }
        with0 += v * na0[node];
        with2 += v * na2[node];
        pd0[node] += v;
        pd2[node] += v;
    }
    return {with0, with2};
}

// One subtree hanging off a centroid, rooted at the centroid's neighbour.
// Local vertices are 0..M-1; slots are assigned in post-order, so the subtree
// of x occupies the contiguous slot range [sti[x] - tsz[x] + 1, sti[x]].
// Maintained counters:
//   n0, n2 : number of type-0 / type-2 vertices
//   n01    : pairs (type-1 vertex a, type-0 vertex in a's subtree)
//   n21    : pairs (type-1 vertex a, type-2 vertex in a's subtree)
struct cst {
    ll M;
    long long n0, n2, n21, n01;
    vector<ll> Fn;     // current type of each local vertex
    ll r;              // root (local label)
    vector<int> tsz;   // subtree sizes
    vector<int> sti;   // slot of each vertex

    cst(ll r0, ll M0, const vector<vector<ll>>& adj, vector<ll> f0)
        : M(M0), n0(0), n2(0), n21(0), n01(0), Fn(std::move(f0)), r(r0),
          tsz(M0, 0), sti(M0, 0) {
        vector<ll> par(M, -1);
        vector<char> seen(M, 0);
        stack<pii> st;  // {vertex, 0 = enter / 1 = exit}
        st.push({r, 0});
        while (!st.empty()) {
            const auto [x, exiting] = st.top();
            st.pop();
            if (!exiting) {
                seen[x] = 1;
                st.push({x, 1});
                for (ll y : adj[x]) {
                    if (!seen[y]) {
                        par[y] = x;
                        st.push({y, 0});
                    }
                }
            } else {
                // Post-order: all children are finished, so x's range is
                // exactly the slots handed out since x was entered.
                tsz[x] = 1;
                for (ll y : adj[x]) {
                    if (par[y] == x) tsz[x] += tsz[y];
                }
                sti[x] = allc++;
                account(x, 1);
            }
        }
    }

    // Adds (d = +1) or removes (d = -1) the contribution of vertex x, with its
    // current type Fn[x], to the counters and the segment tree.
    void account(ll x, ll d) {
        if (Fn[x] == 0) {
            n0 += d;
            n01 += wrt0(sti[x], d);
        } else if (Fn[x] == 2) {
            n2 += d;
            n21 += wrt2(sti[x], d);
        } else {
            const pii got = wrtI(sti[x] - tsz[x] + 1, sti[x], d);
            n01 += got.first;
            n21 += got.second;
        }
    }

    // Changes the type of local vertex x to v.
    void upd(ll x, ll v) {
        account(x, -1);
        Fn[x] = v;
        account(x, 1);
    }
};

// One node of the centroid decomposition: a component relabelled so that the
// centroid is local vertex 0, plus one cst per neighbour of the centroid.
// Aggregates over its subtrees t:
//   s0 = sum n0_t, s2 = sum n2_t, s01 = sum n01_t, s21 = sum n21_t,
//   s210 = sum n21_t*n0_t, s012 = sum n01_t*n2_t, s02 = sum n0_t*n2_t.
// This node contributes crossTerm() + centerTerm() to the global `ans`.
struct cdt {
    ll M;              // component size
    vector<ll> Fn;     // current type of each local vertex
    vector<pii> strl;  // local vertex -> {subtree index, label in that cst}
    vector<cst*> v1;   // subtrees; allocated once, live for the whole program
    long long s21 = 0, s01 = 0, s0 = 0, s2 = 0, s210 = 0, s012 = 0, s02 = 0;

    cdt(ll M1, const vector<vector<ll>>& adj1, vector<ll> Fn1)
        : M(M1), Fn(std::move(Fn1)), strl(M1, pii{0, 0}) {
        ll rcnt = 0;
        for (ll x : adj1[0]) {
            // BFS the subtree behind neighbour x, relabelling it from 0.
            vector<vector<ll>> nadj;
            vector<ll> fnew;
            ll Mn = 0;
            queue<pii> q;
            q.push({x, -1});
            while (!q.empty()) {
                const auto [z, pz] = q.front();
                q.pop();
                if (z == 0) continue;
                rlbl[z] = Mn++;
                nadj.emplace_back();
                fnew.push_back(Fn[z]);
                strl[z] = {rcnt, rlbl[z]};
                if (pz != -1) {
                    nadj[rlbl[z]].push_back(rlbl[pz]);
                    nadj[rlbl[pz]].push_back(rlbl[z]);
                }
                for (ll nz : adj1[z]) {
                    if (nz != pz && nz != 0) q.push({nz, z});
                }
            }
            v1.push_back(new cst(0, Mn, nadj, std::move(fnew)));
            rcnt++;
        }
        for (cst* c : v1) addStats(*c, 1);
        ans += crossTerm() + centerTerm();
    }

    // Tours whose 1-vertex lies in the same subtree as one endpoint, the other
    // endpoint being in a different subtree.
    long long crossTerm() const {
        return s21 * s0 - s210 + s01 * s2 - s012;
    }

    // Tours that use the centroid itself, depending on its type.
    long long centerTerm() const {
        if (Fn[0] == 0) return s21;
        if (Fn[0] == 1) return s0 * s2 - s02;
        assert(Fn[0] == 2);
        return s01;
    }

    // Adds (sg = +1) or removes (sg = -1) subtree c from the aggregates.
    void addStats(const cst& c, long long sg) {
        s21 += sg * c.n21;
        s01 += sg * c.n01;
        s0 += sg * c.n0;
        s2 += sg * c.n2;
        s210 += sg * c.n21 * c.n0;
        s012 += sg * c.n01 * c.n2;
        s02 += sg * c.n0 * c.n2;
    }

    // Changes the type of local vertex x to vf and updates `ans`.
    void upd(ll x, ll vf) {
        if (x == 0) {  // the centroid only affects centerTerm()
            ans -= centerTerm();
            Fn[0] = vf;
            ans += centerTerm();
            return;
        }
        ans -= centerTerm() + crossTerm();
        cst& c = *v1[strl[x].first];
        addStats(c, -1);
        c.upd(strl[x].second, vf);
        addStats(c, 1);
        ans += centerTerm() + crossTerm();
        Fn[x] = vf;
    }
};

vector<ll> adj[Nm];
bool found[Nm];  // vertices already used as a centroid
ll sz[Nm];
vector<cdt*> cdtr;  // all decomposition nodes; live for the whole program
ll dind = 0;        // index the next cdt will get in cdtr

// Computes sz[] for the component containing x (ignoring removed vertices)
// and returns its size. Iterative, so path-shaped trees cannot overflow the
// call stack.
ll getsz(ll x, ll pr = -1) {
    vector<pii> order;  // BFS order: {vertex, parent}
    order.push_back({x, pr});
    for (size_t i = 0; i < order.size(); i++) {
        const auto [u, p] = order[i];
        for (ll w : adj[u]) {
            if (w != p && !found[w]) order.push_back({w, u});
        }
    }
    for (const auto& [u, p] : order) sz[u] = 1;
    for (size_t i = order.size(); i-- > 1;) {
        sz[order[i].second] += sz[order[i].first];
    }
    return sz[x];
}

// Walks from x towards the heavy side until no child holds at least half of
// the sz0 vertices; the vertex reached is a centroid. Requires sz[] from getsz.
ll getctr(ll x, ll sz0, ll pr = -1) {
    while (true) {
        ll nxt = -1;
        for (ll y : adj[x]) {
            if (y != pr && !found[y] && 2 * sz[y] >= sz0) {
                nxt = y;
                break;
            }
        }
        if (nxt == -1) return x;
        pr = x;
        x = nxt;
    }
}

// Builds the centroid decomposition of the component containing x.
void bldDcmp(ll x = 0) {
    const ll sz0 = getsz(x);
    const ll y = getctr(x, sz0);
    vector<vector<ll>> nadj;
    vector<ll> fnew;
    ll M = 0;
    queue<pii> q;
    q.push({y, -1});
    while (!q.empty()) {
        const auto [z, pz] = q.front();
        q.pop();
        rlbl[z] = M++;
        nadj.emplace_back();
        fnew.push_back(F[z]);
        locs[z].push_back({dind, rlbl[z]});
        if (pz != -1) {
            nadj[rlbl[z]].push_back(rlbl[pz]);
            nadj[rlbl[pz]].push_back(rlbl[z]);
        }
        for (ll zn : adj[z]) {
            if (!found[zn] && zn != pz) q.push({zn, z});
        }
    }
    cdtr.push_back(new cdt(M, nadj, std::move(fnew)));
    found[y] = 1;
    dind++;
    for (ll z : adj[y]) {
        if (!found[z]) bldDcmp(z);
    }
}

void init(int N1, vector<int> F1, vector<int> U1, vector<int> V1, int Q) {
    (void)Q;  // the number of queries is not needed
    N = N1;
    F = std::move(F1);
    U = std::move(U1);
    V = std::move(V1);
    for (ll i = 0; i < N - 1; i++) {
        adj[U[i]].push_back(V[i]);
        adj[V[i]].push_back(U[i]);
    }
    bldDcmp();  // build centroid decomposition
}

void change(int x, int y) {
    for (const auto& [d, lx] : locs[x]) cdtr[d]->upd(lx, y);
}

long long num_tours() { return ans; }
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...