//chat gpt is gonna beat us :((
#include "joitour.h"
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
// Optimized solution using Heavy-Light Decomposition + segment tree
// NOTE: internal node indices are 1..N. The public API (init/change) uses
// the problem's 0-based indices, so we convert inputs to 1-based here.
// File-scope state shared by init(), change() and num_tours().
// nGlobal / qGlobal hold the N and Q passed to init(); qGlobal is stored but
// not read anywhere in the visible code.
static int nGlobal, qGlobal;
// Adjacency lists of the tree, indexed by 1-based vertex id.
static vector<vector<int>> g;
static vector<int> Fcol; // 1-based: Fcol[1..N] in {0,1,2}
// Per-vertex rooting / decomposition data filled by dfs_sz() and dfs_hld():
// parentv = parent (0 for the root), depthv = depth, heavy = child with the
// largest subtree (-1 if none), head = top vertex of the vertex's chain,
// pos = 1-based segment-tree position, sz = subtree size.
static vector<int> parentv, depthv, heavy, head, pos, sz;
// Last position handed out by dfs_hld(); reset to 0 in init().
static int curPos;
static vector<int> initJ, initI; // subtree counts of color 0 and 2
// Total number of vertices with color 0 (Jtot) and color 2 (Itot).
static ll Jtot = 0, Itot = 0;
static vector<int> isOme; // node is omelette (F==1)
static vector<int> isParentOme; // node's parent is omelette
// One segment-tree node. Every aggregate is a sum over the positions covered
// by the node of the per-vertex subtree counts subJ (color 0) and subI
// (color 2). A freshly constructed node has every field equal to zero.
struct Node {
    // All covered positions.
    int cnt = 0;
    ll sumJ = 0, sumI = 0, sumJI = 0;  // sums of subJ, subI, subJ*subI

    // Restricted to positions whose vertex is an omelette (F == 1).
    int cntOme = 0;
    ll sumJ_ome = 0, sumI_ome = 0, sumJI_ome = 0;

    // Restricted to positions whose vertex has an omelette parent.
    int cntParentOme = 0;
    ll sumJ_parent = 0, sumI_parent = 0, sumJI_parent = 0;

    // Pending lazy increments to subJ / subI, not yet pushed to the children.
    ll addJ = 0, addI = 0;

    Node() = default;
};
static vector<Node> seg;
// First DFS pass from vertex u with parent p (0 for the root).
// Fills parentv, depthv, sz, heavy and the initial subtree color counts
// initJ (color 0) / initI (color 2) for u and everything below it.
void dfs_sz(int u,int p){
    parentv[u] = p;
    if (p == 0) {
        depthv[u] = 0;
    } else {
        depthv[u] = depthv[p] + 1;
    }
    sz[u] = 1;
    heavy[u] = -1;
    initJ[u] = (Fcol[u] == 0) ? 1 : 0;
    initI[u] = (Fcol[u] == 2) ? 1 : 0;

    for (int child : g[u]) {
        if (child == p) continue;  // do not walk back up the tree
        dfs_sz(child, u);
        initJ[u] += initJ[child];
        initI[u] += initI[child];
        // Keep the first child with the strictly largest subtree as heavy.
        const bool noHeavyYet = (heavy[u] == -1);
        if (noHeavyYet || sz[heavy[u]] < sz[child]) {
            heavy[u] = child;
        }
        sz[u] += sz[child];
    }
}
// Second DFS pass: assigns u to the chain headed by h and gives it the next
// position. The heavy child is visited first so it continues the same chain;
// every other child starts a chain of its own.
void dfs_hld(int u,int h){
    head[u] = h;
    curPos += 1;
    pos[u] = curPos;

    const int heavyChild = heavy[u];
    if (heavyChild != -1) {
        dfs_hld(heavyChild, h);
    }
    for (int child : g[u]) {
        if (child == parentv[u] || child == heavyChild) continue;
        dfs_hld(child, child);
    }
}
// Combines two child summaries into their parent's summary by adding every
// count and sum field-wise. The result carries no pending lazy increment.
Node mergeNode(const Node &a, const Node &b){
    Node merged = a;

    merged.cnt += b.cnt;
    merged.sumJ += b.sumJ;
    merged.sumI += b.sumI;
    merged.sumJI += b.sumJI;

    merged.cntOme += b.cntOme;
    merged.sumJ_ome += b.sumJ_ome;
    merged.sumI_ome += b.sumI_ome;
    merged.sumJI_ome += b.sumJI_ome;

    merged.cntParentOme += b.cntParentOme;
    merged.sumJ_parent += b.sumJ_parent;
    merged.sumI_parent += b.sumI_parent;
    merged.sumJI_parent += b.sumJI_parent;

    // The copy above also copied a's lazy tags; a merged node has none.
    merged.addJ = 0;
    merged.addI = 0;
    return merged;
}
// Applies "subJ += dJ, subI += dI" to every position covered by nd, updating
// all three aggregate groups and accumulating the increment in nd's lazy tags.
void apply_add(Node &nd, ll dJ, ll dI){
    if (dJ == 0 && dI == 0) return;

    // Shifts one (count, sumJ, sumI, sumJI) group:
    //   sum((J+dJ)*(I+dI)) = sumJI + dJ*sumI + dI*sumJ + dJ*dI*count
    // The product sum must be updated first because it uses the old sums.
    auto shiftGroup = [dJ, dI](ll count, ll &sJ, ll &sI, ll &sJI) {
        sJI += dJ * sI + dI * sJ + dJ * dI * count;
        sJ += dJ * count;
        sI += dI * count;
    };

    shiftGroup(nd.cnt, nd.sumJ, nd.sumI, nd.sumJI);
    shiftGroup(nd.cntOme, nd.sumJ_ome, nd.sumI_ome, nd.sumJI_ome);
    shiftGroup(nd.cntParentOme, nd.sumJ_parent, nd.sumI_parent, nd.sumJI_parent);

    nd.addJ += dJ;
    nd.addI += dI;
}
void push(int idx){
ll dJ = seg[idx].addJ, dI = seg[idx].addI;
if(dJ!=0 || dI!=0){
apply_add(seg[idx<<1], dJ, dI);
apply_add(seg[idx<<1|1], dJ, dI);
seg[idx].addJ = seg[idx].addI = 0;
}
}
void build(int idx,int l,int r, const vector<int> &baseNode){
if(l==r){
int u = baseNode[l];
seg[idx].cnt = 1;
ll j = initJ[u];
ll i = initI[u];
seg[idx].sumJ = j;
seg[idx].sumI = i;
seg[idx].sumJI = j * i;
seg[idx].cntOme = isOme[u];
seg[idx].sumJ_ome = isOme[u] ? j : 0;
seg[idx].sumI_ome = isOme[u] ? i : 0;
seg[idx].sumJI_ome = isOme[u] ? j*i : 0;
seg[idx].cntParentOme = isParentOme[u];
seg[idx].sumJ_parent = isParentOme[u] ? j : 0;
seg[idx].sumI_parent = isParentOme[u] ? i : 0;
seg[idx].sumJI_parent = isParentOme[u] ? j*i : 0;
seg[idx].addJ = seg[idx].addI = 0;
} else {
int mid=(l+r)>>1;
build(idx<<1,l,mid,baseNode);
build(idx<<1|1,mid+1,r,baseNode);
seg[idx] = mergeNode(seg[idx<<1], seg[idx<<1|1]);
}
}
// Adds (dJ, dI) to the subtree counts of every position in [ql, qr].
// idx is the current node and [l, r] is the position range it covers.
void range_add(int idx,int l,int r,int ql,int qr,ll dJ,ll dI){
    const bool disjoint = (r < ql) || (qr < l);
    if (disjoint) return;

    const bool fullyCovered = (ql <= l) && (r <= qr);
    if (fullyCovered) {
        apply_add(seg[idx], dJ, dI);
        return;
    }

    // Partial overlap: settle pending work, recurse, then recompute.
    push(idx);
    const int mid = (l + r) >> 1;
    const int leftChild = idx << 1;
    const int rightChild = leftChild | 1;
    range_add(leftChild, l, mid, ql, qr, dJ, dI);
    range_add(rightChild, mid + 1, r, ql, qr, dJ, dI);
    seg[idx] = mergeNode(seg[leftChild], seg[rightChild]);
}
void point_set_ome(int idx,int l,int r,int p,int val){
if(l==r){
seg[idx].cntOme = val;
seg[idx].sumJ_ome = val ? seg[idx].sumJ : 0;
seg[idx].sumI_ome = val ? seg[idx].sumI : 0;
seg[idx].sumJI_ome = val ? seg[idx].sumJI : 0;
return;
}
push(idx);
int mid=(l+r)>>1;
if(p<=mid) point_set_ome(idx<<1,l,mid,p,val);
else point_set_ome(idx<<1|1,mid+1,r,p,val);
seg[idx] = mergeNode(seg[idx<<1], seg[idx<<1|1]);
}
void point_set_parentOme(int idx,int l,int r,int p,int val){
if(l==r){
seg[idx].cntParentOme = val;
seg[idx].sumJ_parent = val ? seg[idx].sumJ : 0;
seg[idx].sumI_parent = val ? seg[idx].sumI : 0;
seg[idx].sumJI_parent = val ? seg[idx].sumJI : 0;
return;
}
push(idx);
int mid=(l+r)>>1;
if(p<=mid) point_set_parentOme(idx<<1,l,mid,p,val);
else point_set_parentOme(idx<<1|1,mid+1,r,p,val);
seg[idx] = mergeNode(seg[idx<<1], seg[idx<<1|1]);
}
inline ll get_sumJ(){ return seg[1].sumJ; }
inline ll get_sumI(){ return seg[1].sumI; }
inline ll get_sumJI(){ return seg[1].sumJI; }
inline int get_cntOme(){ return seg[1].cntOme; }
inline ll get_sumJ_ome(){ return seg[1].sumJ_ome; }
inline ll get_sumI_ome(){ return seg[1].sumI_ome; }
inline ll get_sumJI_ome(){ return seg[1].sumJI_ome; }
inline ll get_sumJI_parent(){ return seg[1].sumJI_parent; }
// Adds (dJ, dI) to the subtree counts of every vertex on the path from v up
// to the root. n is the number of positions in the segment tree.
// Each iteration covers one chain segment [head, current] and then jumps to
// the parent of that chain's head; the walk stops once the parent is 0.
void path_add_root(int v, ll dJ, ll dI, int n){
    for (int cur = v; cur != 0; cur = parentv[head[cur]]) {
        const int chainTop = head[cur];
        range_add(1, 1, n, pos[chainTop], pos[cur], dJ, dI);
    }
}
// Sets up all state for a tree of N vertices.
//   F    - color of each vertex (0-based index, values in {0,1,2})
//   U, V - endpoints of the N-1 edges (0-based vertex ids)
//   Q    - stored in qGlobal only
// Vertices are renumbered to 1..N internally and the tree is rooted at 1.
void init(int N, std::vector<int> F, std::vector<int> U, std::vector<int> V, int Q){
    nGlobal = N;
    qGlobal = Q;

    // Colors, shifted to 1-based ids.
    Fcol.assign(N + 1, 0);
    for (int node = 1; node <= N; ++node) {
        Fcol[node] = F[node - 1];
    }

    // Undirected adjacency lists, shifted to 1-based ids.
    g.assign(N + 1, vector<int>());
    for (int e = 0; e < N - 1; ++e) {
        const int a = U[e] + 1;
        const int b = V[e] + 1;
        g[a].push_back(b);
        g[b].push_back(a);
    }

    // Reset the per-vertex tables, then run both decomposition passes.
    parentv.assign(N + 1, 0);
    depthv.assign(N + 1, 0);
    heavy.assign(N + 1, -1);
    head.assign(N + 1, 0);
    pos.assign(N + 1, 0);
    sz.assign(N + 1, 0);
    initJ.assign(N + 1, 0);
    initI.assign(N + 1, 0);
    curPos = 0;
    dfs_sz(1, 0);
    dfs_hld(1, 1);

    // One sweep over the vertices: color totals, omelette flags, and the
    // position -> vertex table used to build the leaves.
    Jtot = 0;
    Itot = 0;
    isOme.assign(N + 1, 0);
    isParentOme.assign(N + 1, 0);
    vector<int> baseNode(N + 1);
    for (int node = 1; node <= N; ++node) {
        const int color = Fcol[node];
        if (color == 0) {
            ++Jtot;
        } else if (color == 2) {
            ++Itot;
        }
        isOme[node] = (color == 1);
        const int par = parentv[node];
        if (par != 0) {
            isParentOme[node] = (Fcol[par] == 1);
        }
        baseNode[pos[node]] = node;
    }

    seg.assign(4 * (N + 5), Node());
    build(1, 1, N, baseNode);
}
void change(int X, int Y){
// X given 0-based externally -> convert to 1-based
int x = X + 1;
int old = Fcol[x];
if(old == Y) return;
int dJ = 0, dI = 0;
if(old==0) dJ -= 1;
if(old==2) dI -= 1;
if(Y==0) dJ += 1;
if(Y==2) dI += 1;
// update subtree counts: all ancestors (incl. x) change by dJ,dI
if(dJ!=0 || dI!=0){
path_add_root(x, dJ, dI, nGlobal);
Jtot += dJ;
Itot += dI;
}
// update omelette flag at x and the parent flags of its children
if(old==1 && Y!=1){
// turning off omelette
isOme[x]=0;
point_set_ome(1,1,nGlobal,pos[x],0);
for(int v: g[x]) if(parentv[v]==x){
isParentOme[v]=0;
point_set_parentOme(1,1,nGlobal,pos[v],0);
}
} else if(old!=1 && Y==1){
// turning on omelette
isOme[x]=1;
point_set_ome(1,1,nGlobal,pos[x],1);
for(int v: g[x]) if(parentv[v]==x){
isParentOme[v]=1;
point_set_parentOme(1,1,nGlobal,pos[v],1);
}
}
Fcol[x] = Y;
}
long long num_tours(){
// answer = sum_{u:F[u]==1} (Jtot*Itot - S[u])
// compute sumS_ome and then answer
ll cntO = get_cntOme();
ll term1 = get_sumJI_parent(); // sum over nodes whose parent is omelette: subJ*subI
// term2 = sum_{u in ome} (Jtot - subJ[u])*(Itot - subI[u])
// expand: cntO*Jtot*Itot - Jtot*sumI_ome - Itot*sumJ_ome + sumJI_ome
ll term2 = cntO * (Jtot * Itot) - Jtot * get_sumI_ome() - Itot * get_sumJ_ome() + get_sumJI_ome();
ll sumS_ome = term1 + term2;
ll ans = cntO * (Jtot * Itot) - sumS_ome;
return ans;
}
// Residue of an online-judge results table copied along with the source:
// the header "# | Verdict | Execution time | Memory | Grader output" appeared
// seven times, each still showing "Fetching results...". It is kept only as
// this comment because the raw text is not valid C++.