// Submission listing pasted from the judge (kept as a comment so the file compiles):
// # | Time | Username | Problem | Language | Result | Execution time | Memory |
// ---|---|---|---|---|---|---|---|
// 1123312 | Math4Life2020 | JOI tour (JOI24_joitour) | C++20 | 0 ms | 0 KiB |
#pragma GCC optimize("O3,unroll-loops")
#pragma GCC target("avx2")
#include "joitour.h"
#include <bits/stdc++.h>
using namespace std;
// NOTE: despite its name, `ll` is a 32-bit `int` in this file. Quantities that
// can grow quadratically (pair counts, the tour total) must be declared
// `long long` explicitly.
using ll = int; using pii = pair<ll,ll>;
// Capacity of the per-vertex global arrays; assumes N < Nm.
const ll Nm = 2e5+5;
// Number of vertices, set by init().
ll N;
// Current number of tours; returned by num_tours() and updated incrementally
// by the centroid-component structures.
long long ans = 0;
// Vertex types and edge endpoints as passed to init() (F is not updated by change()).
vector<int> F,U,V;
vector<pii> locs[Nm]; //per vertex: {index into cdtr, local label of the vertex inside that component}
// Scratch relabel table (original id -> local label), reused while building components.
ll rlbl[Nm];
//vector<ll> hld;
struct cst { //cdt subtree
    // One subtree hanging off a centroid, relabelled 0..M-1 and rooted at r.
    // After calc() the whole-subtree counters are:
    //   n0, n2 : number of vertices of type 0 / type 2;
    //   n01    : pairs (j, i) with Fn[j]==1, Fn[i]==0 and i a descendant of j;
    //   n21    : pairs (j, k) with Fn[j]==1, Fn[k]==2 and k a descendant of j.
    ll M;                            // number of vertices
    vector<long long> emp;           // M zeros, used to reset the per-vertex arrays
    long long n0,n2,n21,n01;
    vector<ll> Fn;                   // current type of each vertex
    ll r;                            // label of the root
    vector<vector<ll>> fadj;         // children of each vertex when rooted at r
    vector<ll> radj;                 // parent of each vertex (-1 for r and unreached vertices)
    vector<long long> v0,v2,v21,v01; // the counters above, restricted to the subtree of each vertex

    // Fill v0/v2/v01/v21 for every vertex in the subtree of x.
    // Deliberately iterative: a path-shaped subtree can be very deep, and one
    // recursive call per level risks overflowing the call stack.
    void lft(ll x) {
        vector<ll> order;            // pre-order: each vertex precedes its descendants
        vector<ll> stk;
        stk.push_back(x);
        while (!stk.empty()) {
            ll u = stk.back(); stk.pop_back();
            order.push_back(u);
            for (ll y: fadj[u]) stk.push_back(y);
        }
        // Walking the pre-order backwards finishes every child before its parent.
        for (auto it = order.rbegin(); it != order.rend(); ++it) {
            ll u = *it;
            v0[u]=(Fn[u]==0);
            v2[u]=(Fn[u]==2);
            v21[u]=0;
            v01[u]=0;
            for (ll y: fadj[u]) {
                v0[u]+=v0[y];
                v2[u]+=v2[y];
                v01[u]+=v01[y];
                v21[u]+=v21[y];
            }
            if (Fn[u]==1) {
                // u can be the middle checkpoint for every type-0 / type-2 vertex below it.
                v01[u]+=v0[u];
                v21[u]+=v2[u];
            }
        }
    }
    // Recompute every counter from scratch (O(M)).
    void calc() {
        v0=emp; v2=emp; v21=emp; v01=emp;
        lft(r);
        n0=v0[r]; n2=v2[r]; n21=v21[r]; n01=v01[r];
    }
    // adj: undirected adjacency over labels 0..M0-1; f0: type of each label; r0: root.
    cst(ll r0, ll M0, vector<vector<ll>> adj,vector<ll> f0) {
        n0=0; n2=0; n21=0; n01=0;
        r=r0; M=M0;
        Fn=std::move(f0);
        emp.assign(M,0LL);
        radj.assign(M,-1);
        fadj.assign(M,vector<ll>{});
        // Orient the tree away from r with a BFS; marking on push enqueues each vertex once.
        vector<char> seen(M,0);
        queue<ll> q;
        q.push(r);
        seen[r]=1;
        while (!q.empty()) {
            ll x = q.front(); q.pop();
            for (ll y: adj[x]) {
                if (!seen[y]) {
                    seen[y]=1;
                    radj[y]=x;
                    fadj[x].push_back(y);
                    q.push(y);
                }
            }
        }
        calc();
    }
    // Set the type of vertex x to v and recompute all counters (O(M)).
    void upd(ll x, ll v) {
        Fn[x]=v;
        calc();
    }
};
struct cdt { //centroid decomp tree
ll M; //size
vector<vector<ll>> fadj;
vector<ll> Fn; //new F
vector<pii> strl; //subtree locations: {index of st, index in st}
vector<cst*> v1;
long long s21=0, s01=0, s0=0, s2=0, s210=0, s012=0, s02=0;
cdt(ll M1, vector<vector<ll>> fadj1, vector<ll> Fn1) { //fadj is really just adj oops
M=M1; fadj=fadj1; Fn=Fn1;
for (ll m=0;m<M;m++) {
strl.push_back((pii){0,0});
}
ll rcnt = 0;
for (ll x: fadj[0]) {
//unordered_map<ll,ll> rlbl; //relabel
vector<vector<ll>> nadj;
vector<ll> fnew;
ll Mn = 0;
queue<pii> q0;
q0.push({x,-1});
//cout << "x="<<x<<"\n";
while (!q0.empty()) {
pii p0 = q0.front(); q0.pop();
ll z = p0.first; ll pz = p0.second;
if (z==0) {
continue;
}
//if (rlbl.find(z)==rlbl.end()) {
rlbl[z]=Mn++;
//cout << "defining z="<<z<<" as "<<rlbl[z]<<"\n";
nadj.push_back((vector<ll>){});
fnew.push_back(Fn[z]);
strl[z]={rcnt,rlbl[z]};
//cout << "relabel: z="<<z<<"->"<<rlbl[z]<<"\n";
//locs[z].push_back({dind,rlbl[z]});
// }
if (pz != -1) {
//cout << "z,pz="<<z<<","<<pz<<"\n";
nadj[rlbl[z]].push_back(rlbl[pz]);
nadj[rlbl[pz]].push_back(rlbl[z]);
}
for (ll nz: fadj[z]) {
if (nz != pz && nz != 0) {
q0.push({nz,z});
}
}
}
v1.push_back(new cst(0LL,Mn,nadj,fnew));
rcnt++;
}
for (ll r=0;r<rcnt;r++) {
ll a21 = (v1[r]->n21);
ll a01 = (v1[r]->n01);
ll a0 = (v1[r]->n0);
ll a2 = v1[r]->n2;
s21 += a21
s01 += a01;
s0 += a0;
s2 += a2;
s210 += a21*a0;
s012 += a01*a2;
s02 += a0*a2;
}
ans += (s21*s0-s210+s01*s2-s012);
if (Fn[0]==0) {
ans += s21;
} else if (Fn[0]==1) {
ans += (s0*s2-s02);
} else {
ans += s01;
}
}
void upd(ll x, ll vf) {
if (x==0) {
ll v0 = Fn[0];
if (v0==0) {
ans -= s21;
} else if (v0==1) {
ans -= (s0*s2-s02);
} else {
assert(v0==2);
ans -= s01;
}
if (vf==0) {
ans += s21;
} else if (vf==1) {
ans += (s0*s2-s02);
} else {
assert(vf==2);
ans += s01;
}
} else {
ll v0 = Fn[0];
if (v0==0) {
ans -= s21;
} else if (v0==1) {
ans -= (s0*s2-s02);
} else {
assert(v0==2);
ans -= s01;
}
ans -= (s21*s0-s210+s01*s2-s012);
ll i = strl[x].first;
ll& a21 = (v1[i]->n21);
ll& a01 = (v1[i]->n01);
ll& a2 = (v1[i]->n2);
ll& a0 = (v1[i]->n0);
s21 -= a21;
s01 -= a01;
s2 -= a2;
s0 -= a0;
s210 -= a21*a0;
s012 -= a01*a2;
s02 -= a0*a2;
(*v1[i]).upd(strl[x].second,vf);
s21 += a21;
s01 += a01;
s2 += a2;
s0 += a0;
s210 += a21*a0;
s012 += a01*a2;
s02 += a0*a2;
v0 = Fn[0];
if (v0==0) {
ans += s21;
} else if (v0==1) {
ans += (s0*s2-s02);
} else {
assert(v0==2);
ans += s01;
}
ans += (s21*s0-s210+s01*s2-s012);
}
Fn[x]=vf;
}
};
// Input tree adjacency (original vertex ids), filled by init().
vector<ll> adj[Nm];
// Set once a vertex has been used as a centroid; getsz/getctr/bldDcmp skip such
// vertices, which splits the tree into smaller components.
bool found[Nm];
// Subtree sizes written by getsz() for the component currently being decomposed.
ll sz[Nm];
// NOTE(review): not referenced by any code visible here; possibly dead.
ll rev[Nm];
// One centroid component per centroid, in creation order; the first member of
// each locs[] entry indexes into this vector.
vector<cdt*> cdtr;
// Fill sz[] with subtree sizes for every vertex reachable from x without passing
// through pr or an already-removed centroid (found[] set), rooted at x; return sz[x].
// Iterative so that long paths (depth up to N) cannot overflow the call stack.
ll getsz(ll x, ll pr = -1) {
    vector<pii> order; // (vertex, parent) in BFS order: parents precede children
    order.push_back({x, pr});
    for (size_t i = 0; i < order.size(); i++) {
        const ll u = order[i].first;
        const ll p = order[i].second;
        for (ll y: adj[u]) {
            if (y != p && !found[y]) {
                order.push_back({y, u});
            }
        }
    }
    for (const pii& e: order) sz[e.first] = 1;
    // Sweeping backwards completes every child before it is added to its parent.
    // Index 0 is x itself, whose parent (pr) is outside the subtree and untouched.
    for (size_t i = order.size(); i-- > 1;) {
        sz[order[i].second] += sz[order[i].first];
    }
    return sz[x];
}
// Starting at x, repeatedly step into a child (relative to the root used by the
// preceding getsz call) whose subtree holds more than half of the sz0 vertices;
// the vertex where no such child exists is returned as the centroid.
// A loop instead of recursion: the walk can be about sz0/2 steps long.
ll getctr(ll x, ll sz0, ll pr=-1) {
    bool moved = true;
    while (moved) {
        moved = false;
        for (ll y: adj[x]) {
            if (y != pr && !found[y] && 2*sz[y]>sz0) {
                pr = x;
                x = y;
                moved = true;
                break; // adj[x] changed; restart the scan from the new vertex
            }
        }
    }
    return x;
}
// Index in cdtr that the next component built by bldDcmp will occupy; bldDcmp
// records it in locs[] and increments it after each push.
ll dind = 0; //index in cdtr
// Build the centroid decomposition of the component containing x (vertices with
// found[] unset): create one cdt for this component, then recurse into the pieces
// left after removing its centroid.
void bldDcmp(ll x=0) {
    const ll total = getsz(x);
    const ll c = getctr(x, total);

    // BFS the component from the centroid; the visit index is the local label,
    // so the centroid becomes label 0.
    vector<pii> visit;
    visit.push_back({c, -1});
    vector<vector<ll>> nadj; //new adjacency
    vector<ll> fnew;
    for (size_t head = 0; head < visit.size(); ++head) {
        const ll z = visit[head].first;
        const ll pz = visit[head].second;
        rlbl[z] = (ll)head;
        nadj.emplace_back();
        fnew.push_back(F[z]);
        locs[z].push_back({dind, rlbl[z]});
        if (pz != -1) {
            nadj[rlbl[z]].push_back(rlbl[pz]);
            nadj[rlbl[pz]].push_back(rlbl[z]);
        }
        for (ll nb: adj[z]) {
            if (nb != pz && !found[nb]) visit.push_back({nb, z});
        }
    }
    const ll M = (ll)visit.size();
    cdtr.push_back(new cdt(M,nadj,fnew));

    // Remove the centroid and decompose each remaining piece.
    found[c] = 1;
    ++dind;
    for (ll nb: adj[c]) {
        if (!found[nb]) bldDcmp(nb);
    }
}
void init(int N1, vector<int> F1, vector<int> U1, vector<int> V1, int Q) {
N=N1;
F=F1;
U=U1;
V=V1;
for (ll i=0;i<(N-1);i++) {
adj[U[i]].push_back(V[i]);
adj[V[i]].push_back(U[i]);
}
bldDcmp(); //build centroid decomposition
}
void change(int x, int y) {
for (pii p0: locs[x]) {
(*cdtr[p0.first]).upd(p0.second,y);
}
}
long long num_tours() {
return ans;
}