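/*
Submission record (pasted from the judge's listing; the data row below lacks
the Time and Result fields):
# | Time | Username | Problem | Language | Result | Execution time | Memory |
---|---|---|---|---|---|---|---|
1122992 | Math4Life2020 | JOI tour (JOI24_joitour) | C++20 | 0 ms | 0 KiB |
*/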
#include "joitour.h"
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
using pii = pair<ll, ll>;

// Upper bound on the number of tree vertices; sizes the global arrays.
constexpr ll Nm = 200005;
// Large sentinel value (not referenced in the visible code).
constexpr ll INF = 1000000000000000000LL;

// Number of vertices, as passed to init().
ll N;
// Running answer reported by num_tours(); adjusted incrementally.
ll ans = 0;
// Copies of the init() arguments: vertex values and edge endpoint lists.
vector<int> F;
vector<int> U;
vector<int> V;
// For each original vertex, one {index into cdtr, local label} entry per
// centroid component that contains the vertex.
vector<pii> locs[Nm];
//vector<ll> hld;
struct cst { //cdt subtree
ll M;
vector<ll> emp;
ll n0,n2,n21,n01;
vector<ll> Fn;
ll r;
vector<vector<ll>> fadj;
vector<ll> radj;
vector<ll> v0,v2,v21,v01;
void lft(ll x) {
v0[x]=(Fn[x]==0);
v2[x]=(Fn[x]==2);
v21[x]=0;
v01[x]=0;
for (ll y: fadj[x]) {
lft(y);
v0[x]+=v0[y];
v2[x]+=v2[y];
v01[x]+=v01[y];
v21[x]+=v21[y];
}
if (Fn[x]==1) {
v01[x]+=v0[x];
v21[x]+=v2[x];
}
}
void calc() {
v0=emp; v2=emp; v21=emp; v01=emp;
lft(r);
n0=v0[r]; n2=v2[r]; n21=v21[r]; n01=v01[r];
}
cst(ll r0, ll M0, vector<vector<ll>> adj,vector<ll> f0) {
n0=0; n2=0; n21=0; n01=0;
r=r0; M=M0;
// cout << "cst M = "<<M<<"\n";
// cout << "f0 elem=\n";
// for (ll x: f0) {
// cout << x << " ";
// }
// cout << "\n nadj:\n";
// for (ll x=0;x<adj.size();x++) {
// for (ll y: adj[x]) {
// cout << "x,y="<<x<<","<<y<<"\n";
// }
// }
Fn=f0;
vector<bool> found;
for (ll m=0;m<M;m++) {
emp.push_back(0);
radj.push_back(-1);
found.push_back(0);
fadj.push_back((vector<ll>){});
}
queue<ll> q;
q.push(r);
while (!q.empty()) {
ll x = q.front(); q.pop();
found[x]=1;
for (ll y: adj[x]) {
if (!found[y]) {
radj[y]=x;
q.push(y);
fadj[x].push_back(y);
}
}
}
// cout << "fadj:\n";
// for (ll x=0;x<fadj.size();x++) {
// for (ll y: fadj[x]) {
// cout << "x,y="<<x<<","<<y<<"\n";
// }
// }
calc();
}
void upd(ll x, ll v) {
Fn[x]=v;
//cout << "x,v="<<x<<","<<v<<"\n";
calc();
}
};
struct cdt { //centroid decomp tree
ll M; //size
vector<vector<ll>> fadj;
vector<ll> Fn; //new F
vector<pii> strl; //subtree locations: {index of st, index in st}
vector<cst*> v1;
ll s21=0, s01=0, s0=0, s2=0, s210=0, s012=0, s02=0;
cdt(ll M1, vector<vector<ll>> fadj1, vector<ll> Fn1) { //fadj is really just adj oops
M=M1; fadj=fadj1; Fn=Fn1;
// for (ll x: Fn) {
// cout << "fn term = "<<x<<"\n";
// }
for (ll m=0;m<M;m++) {
strl.push_back((pii){0,0});
}
// cout << "fadj: \n";
// for (ll x=0;x<fadj1.size();x++) {
// for (ll y: fadj1[x]) {
// cout << "x,y: "<<x<<","<<y<<"\n";
// }
// }
//root is 0
ll rcnt = 0;
for (ll x: fadj[0]) {
map<ll,ll> rlbl; //relabel
vector<vector<ll>> nadj;
vector<ll> fnew;
ll Mn = 0;
queue<pii> q0;
q0.push({x,-1});
//cout << "x="<<x<<"\n";
while (!q0.empty()) {
pii p0 = q0.front(); q0.pop();
ll z = p0.first; ll pz = p0.second;
if (z==0) {
continue;
}
if (rlbl.find(z)==rlbl.end()) {
rlbl[z]=Mn++;
//cout << "defining z="<<z<<" as "<<rlbl[z]<<"\n";
nadj.push_back((vector<ll>){});
fnew.push_back(Fn[z]);
strl[z]={rcnt,rlbl[z]};
//cout << "relabel: z="<<z<<"->"<<rlbl[z]<<"\n";
//locs[z].push_back({dind,rlbl[z]});
}
if (pz != -1) {
//cout << "z,pz="<<z<<","<<pz<<"\n";
nadj[rlbl[z]].push_back(rlbl[pz]);
nadj[rlbl[pz]].push_back(rlbl[z]);
}
for (ll nz: fadj[z]) {
if (nz != pz && nz != 0) {
q0.push({nz,z});
}
}
}
v1.push_back(new cst(0LL,Mn,nadj,fnew));
rcnt++;
}
for (ll r=0;r<rcnt;r++) {
s21 += (v1[r]->n21);
s01 += (v1[r]->n01);
s0 += (v1[r]->n0);
s2 += v1[r]->n2;
s210 += (v1[r]->n21)*(v1[r]->n0);
s012 += (v1[r]->n01)*(v1[r]->n2);
s02 += (v1[r]->n0)*(v1[r]->n2);
}
ans += (s21*s0-s210+s01*s2-s012);
if (Fn[0]==0) {
ans += s21;
} else if (Fn[0]==1) {
ans += (s0*s2-s02);
} else {
ans += s01;
}
}
void upd(ll x, ll vf) {
// cout << "updating where M="<<M<<"\n";
// cout << "initial ans = "<<ans<<"\n";
if (x==0) {
ll v0 = Fn[0];
if (v0==0) {
ans -= s21;
} else if (v0==1) {
ans -= (s0*s2-s02);
} else {
assert(v0==2);
ans -= s01;
}
if (vf==0) {
ans += s21;
} else if (vf==1) {
ans += (s0*s2-s02);
} else {
assert(vf==2);
ans += s01;
}
} else {
ll v0 = Fn[0];
if (v0==0) {
ans -= s21;
} else if (v0==1) {
ans -= (s0*s2-s02);
} else {
assert(v0==2);
ans -= s01;
}
ans -= (s21*s0-s210+s01*s2-s012);
ll i = strl[x].first;
s21 -= (v1[i]->n21);
s01 -= (v1[i]->n01);
s2 -= (v1[i]->n2);
s0 -= (v1[i]->n0);
s210 -= (v1[i]->n21)*(v1[i]->n0);
s012 -= (v1[i]->n01)*(v1[i]->n2);
s02 -= (v1[i]->n01)*(v1[i]->n2);
(*v1[i]).upd(strl[x].second,vf);
i = strl[x].first;
s21 += (v1[i]->n21);
s01 += (v1[i]->n01);
s2 += (v1[i]->n2);
//cout << "final s21="<<s21<<"\n";
s0 += (v1[i]->n0);
s210 += (v1[i]->n21)*(v1[i]->n0);
s012 += (v1[i]->n01)*(v1[i]->n2);
s02 += (v1[i]->n01)*(v1[i]->n2);
v0 = Fn[0];
if (v0==0) {
ans += s21;
} else if (v0==1) {
ans += (s0*s2-s02);
} else {
assert(v0==2);
ans += s01;
}
ans += (s21*s0-s210+s01*s2-s012);
}
Fn[x]=vf;
//cout << "final ans = "<<ans<<"\n";
}
};
// Adjacency lists of the original tree (filled by init()).
vector<ll> adj[Nm]{};
// found[v] becomes true once v has been picked as a centroid by bldDcmp().
bool found[Nm]{};
// Scratch subtree sizes written by getsz() and read by getctr().
ll sz[Nm]{};
// Not referenced anywhere in the visible code.
ll rev[Nm]{};
// One centroid component per centroid, in the order bldDcmp() creates them.
vector<cdt*> cdtr{};
// Roots the active (not yet removed) component containing x at x, with pr
// treated as x's parent and skipped, and fills sz[v] with the size of v's
// subtree for every vertex v reached. Returns sz[x], the component size.
// Iterative on purpose: a recursive DFS reaches depth ~N on a path-shaped
// tree, which can overflow the call stack.
ll getsz(ll x, ll pr = -1) {
    vector<pii> order; // {vertex, parent}; every parent precedes its children
    order.push_back(pii(x, pr));
    for (size_t k = 0; k < order.size(); k++) {
        const ll v = order[k].first;
        const ll p = order[k].second;
        sz[v] = 1;
        for (ll y : adj[v]) {
            if (y != p && !found[y]) {
                order.push_back(pii(y, v));
            }
        }
    }
    // Children come after their parent, so a reverse sweep folds each
    // finished subtree size into its parent (index 0 is x itself).
    for (size_t k = order.size() - 1; k >= 1; k--) {
        sz[order[k].second] += sz[order[k].first];
    }
    return sz[x];
}
// Starting at x (arriving from pr), repeatedly steps into the first active
// neighbour whose subtree holds more than half of the sz0 vertices. Returns
// the first vertex with no such neighbour, which is a centroid of the
// component. Requires sz[] to have been filled by getsz() rooted at x.
ll getctr(ll x, ll sz0, ll pr=-1) {
    ll cur = x;
    ll from = pr;
    for (;;) {
        ll heavy = -1;
        for (ll y : adj[cur]) {
            if (y == from || found[y]) {
                continue;
            }
            if (2*sz[y] > sz0) {
                heavy = y;
                break;
            }
        }
        if (heavy == -1) {
            return cur;
        }
        from = cur;
        cur = heavy;
    }
}
ll dind = 0; //index in cdtr

// Builds the centroid decomposition of the active component containing x.
// It finds the component's centroid c and relabels the component by BFS from
// c, so c becomes local vertex 0. For every member v it appends
// {dind, local label} to locs[v] and pushes a cdt for the component onto
// cdtr. It then marks c as removed and recurses into each component left
// around c.
void bldDcmp(ll x=0) {
    const ll compSize = getsz(x);
    const ll c = getctr(x, compSize);
    map<ll,ll> label;             // original vertex -> local label
    vector<vector<ll>> localAdj;  // adjacency in local labels
    vector<ll> localF;            // value of each local vertex
    vector<pii> work;             // BFS worklist of {vertex, parent}
    work.push_back(pii(c, -1));
    for (size_t head = 0; head < work.size(); head++) {
        const ll v = work[head].first;
        const ll par = work[head].second;
        if (label.count(v) == 0) {
            const ll id = (ll)localAdj.size();
            label[v] = id;
            localAdj.push_back(vector<ll>());
            localF.push_back(F[v]);
            locs[v].push_back(pii(dind, id));
        }
        if (par != -1) {
            const ll a = label[v];
            const ll b = label[par];
            localAdj[a].push_back(b);
            localAdj[b].push_back(a);
        }
        for (ll w : adj[v]) {
            if (w != par && !found[w]) {
                work.push_back(pii(w, v));
            }
        }
    }
    cdtr.push_back(new cdt((ll)localAdj.size(), localAdj, localF));
    found[c] = 1;
    dind++;
    for (ll w : adj[c]) {
        if (!found[w]) {
            bldDcmp(w);
        }
    }
}
void init(int N1, vector<int> F1, vector<int> U1, vector<int> V1) {
N=N1;
F=F1;
U=U1;
V=V1;
for (ll i=0;i<(N-1);i++) {
adj[U[i]].push_back(V[i]);
adj[V[i]].push_back(U[i]);
}
bldDcmp(); //build centroid decomposition
}
// Grader entry point. Sets vertex x's value to y in every centroid component
// that contains x (listed in locs[x]); each component updates `ans` itself.
// NOTE(review): confirm this matches joitour.h. A header declaring
// change(int, int) would not bind to this (long long, long long) definition.
void change(ll x, ll y) {
    const vector<pii>& where = locs[x];
    for (size_t k = 0; k < where.size(); k++) {
        cdt* comp = cdtr[where[k].first];
        comp->upd(where[k].second, y);
    }
}
// Grader entry point: the current answer, as maintained by the cdt
// constructors (via init()) and by cdt::upd (via change()).
ll num_tours() {
    return ::ans;
}
/*int main() {
//ios_base::sync_with_stdio(false); cin.tie(0);
ll N0; cin >> N0;
vector<int> F1;
for (ll x=0;x<N0;x++) {
ll y; cin >> y; F1.push_back(y);
}
vector<int> U1,V1;
for (ll i=0;i<(N0-1);i++) {
ll x,y; cin >> x >> y;
U1.push_back(x);
V1.push_back(y);
}
init((int)N0,F1,U1,V1);
ll Q; cin >> Q;
cout << num_tours() <<"\n";
for (ll q=0;q<Q;q++) {
ll x,y; cin >> x >> y;
change(x,y);
cout << num_tours() <<"\n";
}
}*/