제출 #1123325

#제출 시각아이디문제언어결과실행 시간메모리
1123325Math4Life2020JOI tour (JOI24_joitour)C++20
20 / 100
3100 ms800628 KiB
#pragma GCC optimize("O3,unroll-loops") #pragma GCC target("avx2") #include "joitour.h" #include <bits/stdc++.h> using namespace std; using ll = int; using pii = pair<ll,ll>; const ll Nm = 2e5+5; ll N; long long ans = 0; vector<int> F,U,V; vector<pii> locs[Nm]; //{index of subtree, index in subtree} ll rlbl[Nm]; //vector<ll> hld; struct cst { //cdt subtree ll M; vector<long long> emp; long long n0,n2,n21,n01; vector<ll> Fn; ll r; vector<vector<ll>> fadj; vector<ll> radj; vector<long long> v0,v2,v21,v01; void lft(ll x) { v0[x]=(Fn[x]==0); v2[x]=(Fn[x]==2); v21[x]=0; v01[x]=0; for (ll y: fadj[x]) { lft(y); v0[x]+=v0[y]; v2[x]+=v2[y]; v01[x]+=v01[y]; v21[x]+=v21[y]; } if (Fn[x]==1) { v01[x]+=v0[x]; v21[x]+=v2[x]; } } void calc() { v0=emp; v2=emp; v21=emp; v01=emp; lft(r); n0=v0[r]; n2=v2[r]; n21=v21[r]; n01=v01[r]; } cst(ll r0, ll M0, vector<vector<ll>> adj,vector<ll> f0) { n0=0; n2=0; n21=0; n01=0; r=r0; M=M0; Fn=f0; vector<bool> found; for (ll m=0;m<M;m++) { emp.push_back(0); radj.push_back(-1); found.push_back(0); fadj.push_back((vector<ll>){}); } queue<ll> q; q.push(r); while (!q.empty()) { ll x = q.front(); q.pop(); found[x]=1; for (ll y: adj[x]) { if (!found[y]) { radj[y]=x; q.push(y); fadj[x].push_back(y); } } } calc(); } void upd(ll x, ll v) { Fn[x]=v; calc(); } }; struct cdt { //centroid decomp tree ll M; //size vector<vector<ll>> fadj; vector<ll> Fn; //new F vector<pii> strl; //subtree locations: {index of st, index in st} vector<cst*> v1; long long s21=0, s01=0, s0=0, s2=0, s210=0, s012=0, s02=0; cdt(ll M1, vector<vector<ll>> fadj1, vector<ll> Fn1) { //fadj is really just adj oops M=M1; fadj=fadj1; Fn=Fn1; for (ll m=0;m<M;m++) { strl.push_back((pii){0,0}); } ll rcnt = 0; for (ll x: fadj[0]) { //unordered_map<ll,ll> rlbl; //relabel vector<vector<ll>> nadj; vector<ll> fnew; ll Mn = 0; queue<pii> q0; q0.push({x,-1}); //cout << "x="<<x<<"\n"; while (!q0.empty()) { pii p0 = q0.front(); q0.pop(); ll z = p0.first; ll pz = p0.second; if (z==0) { continue; } //if (rlbl.find(z)==rlbl.end()) { rlbl[z]=Mn++; //cout << "defining z="<<z<<" as "<<rlbl[z]<<"\n"; nadj.push_back((vector<ll>){}); fnew.push_back(Fn[z]); strl[z]={rcnt,rlbl[z]}; //cout << "relabel: z="<<z<<"->"<<rlbl[z]<<"\n"; //locs[z].push_back({dind,rlbl[z]}); // } if (pz != -1) { //cout << "z,pz="<<z<<","<<pz<<"\n"; nadj[rlbl[z]].push_back(rlbl[pz]); nadj[rlbl[pz]].push_back(rlbl[z]); } for (ll nz: fadj[z]) { if (nz != pz && nz != 0) { q0.push({nz,z}); } } } v1.push_back(new cst(0LL,Mn,nadj,fnew)); rcnt++; } for (ll r=0;r<rcnt;r++) { s21 += (v1[r]->n21); s01 += (v1[r]->n01); s0 += (v1[r]->n0); s2 += v1[r]->n2; s210 += (v1[r]->n21)*(v1[r]->n0); s012 += (v1[r]->n01)*(v1[r]->n2); s02 += (v1[r]->n0)*(v1[r]->n2); } ans += (s21*s0-s210+s01*s2-s012); if (Fn[0]==0) { ans += s21; } else if (Fn[0]==1) { ans += (s0*s2-s02); } else { ans += s01; } } void upd(ll x, ll vf) { if (x==0) { ll v0 = Fn[0]; if (v0==0) { ans -= s21; } else if (v0==1) { ans -= (s0*s2-s02); } else { assert(v0==2); ans -= s01; } if (vf==0) { ans += s21; } else if (vf==1) { ans += (s0*s2-s02); } else { assert(vf==2); ans += s01; } } else { ll v0 = Fn[0]; if (v0==0) { ans -= s21; } else if (v0==1) { ans -= (s0*s2-s02); } else { assert(v0==2); ans -= s01; } ans -= (s21*s0-s210+s01*s2-s012); ll i = strl[x].first; s21 -= (v1[i]->n21); s01 -= (v1[i]->n01); s2 -= (v1[i]->n2); s0 -= (v1[i]->n0); s210 -= (v1[i]->n21)*(v1[i]->n0); s012 -= (v1[i]->n01)*(v1[i]->n2); s02 -= (v1[i]->n0)*(v1[i]->n2); (*v1[i]).upd(strl[x].second,vf); i = strl[x].first; s21 += (v1[i]->n21); s01 += (v1[i]->n01); s2 += (v1[i]->n2); s0 += (v1[i]->n0); s210 += (v1[i]->n21)*(v1[i]->n0); s012 += (v1[i]->n01)*(v1[i]->n2); s02 += (v1[i]->n0)*(v1[i]->n2); v0 = Fn[0]; if (v0==0) { ans += s21; } else if (v0==1) { ans += (s0*s2-s02); } else { assert(v0==2); ans += s01; } ans += (s21*s0-s210+s01*s2-s012); } Fn[x]=vf; } }; vector<ll> adj[Nm]; bool found[Nm]; ll sz[Nm]; ll rev[Nm]; vector<cdt*> cdtr; ll getsz(ll x, ll pr = -1) { sz[x]=1; for (ll y: adj[x]) { if (y != pr && !found[y]) { sz[x]+=getsz(y,x); } } return sz[x]; } ll getctr(ll x, ll sz0, ll pr=-1) { for (ll y: adj[x]) { if (y != pr && !found[y]) { if (2*sz[y]>=sz0) { return getctr(y,sz0,x); } } } return x; } ll dind = 0; //index in cdtr void bldDcmp(ll x=0) { //start, previous ll sz0 = getsz(x); ll y = getctr(x,sz0); vector<vector<ll>> nadj; //new adjacency vector<ll> fnew; ll M = 0; queue<pii> q0; q0.push({y,-1}); while (!q0.empty()) { pii p0 = q0.front(); q0.pop(); ll z = p0.first; ll pz = p0.second; rlbl[z]=M++; nadj.push_back((vector<ll>){}); fnew.push_back(F[z]); locs[z].push_back({dind,rlbl[z]}); if (pz != -1) { nadj[rlbl[z]].push_back(rlbl[pz]); nadj[rlbl[pz]].push_back(rlbl[z]); } for (ll zn: adj[z]) { if (!found[zn] && zn != pz) { q0.push({zn,z}); } } } cdtr.push_back(new cdt(M,nadj,fnew)); found[y]=1; dind++; for (ll z: adj[y]) { if (!found[z]) { bldDcmp(z); } } } void init(int N1, vector<int> F1, vector<int> U1, vector<int> V1, int Q) { N=N1; F=F1; U=U1; V=V1; for (ll i=0;i<(N-1);i++) { adj[U[i]].push_back(V[i]); adj[V[i]].push_back(U[i]); } bldDcmp(); //build centroid decomposition } void change(int x, int y) { for (pii p0: locs[x]) { (*cdtr[p0.first]).upd(p0.second,y); } } long long num_tours() { return ans; }
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...