#include "joitour.h"
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;

// Capacity limits. Nodes are stored 1-based, so index N must be valid.
constexpr int MAXN = 200000 + 5;   // maximum number of nodes (+ slack)
constexpr int MAXE = 2 * MAXN;     // each undirected edge is stored twice
constexpr int SEG_SIZE = 4 * MAXN; // segment tree over MAXN positions

// Adjacency lists (singly linked): head[u] is u's first edge id (0 = none),
// to[e] is the far endpoint of edge e, nxt[e] the next edge in the same list,
// and ecnt the last edge id handed out.
static int head[MAXN];
static int to[MAXE];
static int nxt[MAXE];
static int ecnt;

// Tree shape and heavy-light decomposition (1-based node ids).
static int nGlobal;            // number of nodes N
static int parentv[MAXN];      // parent, 0 for the root
static int depthv[MAXN];       // depth, root = 0
static int heavy[MAXN];        // heavy child, -1 for a leaf
static int headH[MAXN];        // head of the heavy chain containing the node
static int pos[MAXN];          // node -> segment-tree position
static int szArr[MAXN];        // subtree size
static int curPos;             // last position handed out during HLD
static int baseNode[MAXN];     // segment-tree position -> node

// Colours and initial subtree counts.
static int Fcol[MAXN];                 // 1-based colour: 0/1/2
static int initJ[MAXN];                // subtree count of colour 0 at init time
static int initI[MAXN];                // subtree count of colour 2 at init time
static ll Jtot;                        // current number of colour-0 nodes
static ll Itot;                        // current number of colour-2 nodes

// Omelette flags (colour 1): on the node itself / on the node's parent.
static int isOme[MAXN];
static int isParentOme[MAXN];

// Segment-tree summary over a range of HLD positions. subJ / subI denote the
// colour-0 / colour-2 counts inside the subtree of the node at a position.
struct Node {
    // over every position in the range
    int cnt;                               // number of positions
    ll sumJ, sumI, sumJI;                  // sum of subJ, subI, subJ*subI
    // restricted to positions whose node is an omelette
    int cntOme;
    ll sumJ_ome, sumI_ome, sumJI_ome;
    // restricted to positions whose node's parent is an omelette
    int cntParentOme;
    ll sumJ_parent, sumI_parent, sumJI_parent;
    // pending lazy addition to subJ / subI not yet pushed to the children
    ll addJ, addI;
};
static Node seg[SEG_SIZE];
// ---------- adjacency helpers ----------
// Prepend the directed edge u -> v to u's adjacency list. Edge ids start at 1
// so that 0 can terminate a list (head[u] == 0 means u has no edges yet).
static inline void add_edge(int u,int v){
    const int id = ++ecnt;
    nxt[id] = head[u];
    to[id] = v;
    head[u] = id;
}
// ---------- iterative dfs to compute sz, parent, depth, initJ/I, heavy ----------
// Post-order DFS from `root` driven by an explicit stack instead of recursion.
// For every node u reachable from root it fills:
//   parentv[u], depthv[u] : set when u is discovered; parentv[root] = 0
//   szArr[u]              : number of nodes in u's subtree
//   initJ[u], initI[u]    : nodes in u's subtree with Fcol == 0 / Fcol == 2
//   heavy[u]              : child with the largest subtree (the first such
//                           child in adjacency order on ties), -1 if no child
// Expects Fcol[] and the adjacency lists to be filled, and node ids to be
// 1-based because 0 doubles as the "no parent" value.
static void dfs_sz_iter(int root){
// stack: pair(node, state). state 0 = enter, 1 = exit
// Capacity (assuming the input is a tree): every node is pushed once in the
// "enter" state and its "exit" entry reuses the slot that was just popped,
// so at most N entries are live at any time.
static int st_node[MAXN];
static char st_state[MAXN];
int sp = 0;
st_node[sp] = root; st_state[sp] = 0; sp++;
parentv[root] = 0;
depthv[root] = 0;
while(sp){
--sp;
int u = st_node[sp];
char state = st_state[sp];
if(state == 0){
// enter: re-push u as "exit" BEFORE its children, so it is popped only after
// every child subtree has been completely processed (post-order).
st_node[sp] = u; st_state[sp] = 1; sp++;
// push children (every neighbour except the parent)
for(int e = head[u]; e; e = nxt[e]){
int v = to[e];
if(v == parentv[u]) continue;
parentv[v] = u;
depthv[v] = depthv[u] + 1;
st_node[sp] = v; st_state[sp] = 0; sp++;
}
} else {
// exit: every child has already exited, so its szArr/initJ/initI are final.
// Compute sz, aggregate initJ/initI, and pick the heavy child.
szArr[u] = 1;
initJ[u] = (Fcol[u] == 0) ? 1 : 0;
initI[u] = (Fcol[u] == 2) ? 1 : 0;
int best = -1;
int bestsz = 0;
for(int e = head[u]; e; e = nxt[e]){
int v = to[e];
if(v == parentv[u]) continue;
szArr[u] += szArr[v];
initJ[u] += initJ[v];
initI[u] += initI[v];
// strict '>' keeps the first child on ties
if(szArr[v] > bestsz){ bestsz = szArr[v]; best = v; }
}
heavy[u] = best;
}
}
}
// ---------- iterative HLD decomposition ----------
// Heavy-light decomposition using heavy[] and parentv[] computed by
// dfs_sz_iter. Assigns pos[v] = 1, 2, ... so that each heavy chain occupies
// consecutive positions, top (chain head) first. headH[v] is the head of v's
// chain, and baseNode[] is the inverse of pos[]. Light children are deferred
// on a stack, and each one later starts a chain of its own.
static void dfs_hld_iter(int root){
curPos = 0;
// stack of chain heads to start; each node is pushed at most once (as a light
// child, or as root), so MAXN slots suffice
static int st[MAXN];
int sp = 0;
st[sp++] = root;
while(sp){
int u = st[--sp];
int h = u;
// go down heavy chain (heavy[v] == -1 marks its bottom)
for(int v = u; v != -1; v = heavy[v]){
headH[v] = h;
pos[v] = ++curPos;
baseNode[curPos] = v;
// push light children to stack for later processing
for(int e = head[v]; e; e = nxt[e]){
int c = to[e];
if(c == parentv[v] || c == heavy[v]) continue;
st[sp++] = c;
}
}
}
}
// ---------- segment tree helpers ----------
// Combine two child summaries. Every aggregate is additive, so each field of
// the result is the field-wise sum; the result carries no pending lazy add.
static inline Node mergeNode(const Node &a, const Node &b){
    Node res;
    res.addJ = 0;
    res.addI = 0;

    // all positions
    res.cnt   = a.cnt   + b.cnt;
    res.sumJ  = a.sumJ  + b.sumJ;
    res.sumI  = a.sumI  + b.sumI;
    res.sumJI = a.sumJI + b.sumJI;

    // omelette positions
    res.cntOme    = a.cntOme    + b.cntOme;
    res.sumJ_ome  = a.sumJ_ome  + b.sumJ_ome;
    res.sumI_ome  = a.sumI_ome  + b.sumI_ome;
    res.sumJI_ome = a.sumJI_ome + b.sumJI_ome;

    // positions whose parent is an omelette
    res.cntParentOme = a.cntParentOme + b.cntParentOme;
    res.sumJ_parent  = a.sumJ_parent  + b.sumJ_parent;
    res.sumI_parent  = a.sumI_parent  + b.sumI_parent;
    res.sumJI_parent = a.sumJI_parent + b.sumJI_parent;

    return res;
}
// Apply "add dJ to every subJ and dI to every subI" to a whole segment.
// For each tracked subset with size c and sums (SJ, SI, SJI):
//   sum (j + dJ)(i + dI) = SJI + dJ*SI + dI*SJ + dJ*dI*c
// so the product sum must be updated BEFORE SJ and SI are shifted.
// The addition is also recorded in the lazy tag for later push-down.
static inline void apply_add_node(Node &nd, ll dJ, ll dI){
    if(!(dJ || dI)) return;
    auto shift = [dJ, dI](ll &sj, ll &si, ll &sji, ll c){
        sji += dJ * si + dI * sj + dJ * dI * c;
        sj += dJ * c;
        si += dI * c;
    };
    shift(nd.sumJ, nd.sumI, nd.sumJI, nd.cnt);
    shift(nd.sumJ_ome, nd.sumI_ome, nd.sumJI_ome, nd.cntOme);
    shift(nd.sumJ_parent, nd.sumI_parent, nd.sumJI_parent, nd.cntParentOme);
    nd.addJ += dJ;
    nd.addI += dI;
}
static inline void push_down(int idx){
ll dJ = seg[idx].addJ;
ll dI = seg[idx].addI;
if(dJ!=0 || dI!=0){
apply_add_node(seg[idx<<1], dJ, dI);
apply_add_node(seg[idx<<1|1], dJ, dI);
seg[idx].addJ = seg[idx].addI = 0;
}
}
static void build_seg(int idx,int l,int r){
if(l==r){
int u = baseNode[l];
Node &nd = seg[idx];
nd.cnt = 1;
ll j = initJ[u];
ll i = initI[u];
nd.sumJ = j;
nd.sumI = i;
nd.sumJI = j * i;
nd.cntOme = isOme[u];
nd.sumJ_ome = isOme[u] ? j : 0;
nd.sumI_ome = isOme[u] ? i : 0;
nd.sumJI_ome = isOme[u] ? j * i : 0;
nd.cntParentOme = isParentOme[u];
nd.sumJ_parent = isParentOme[u] ? j : 0;
nd.sumI_parent = isParentOme[u] ? i : 0;
nd.sumJI_parent = isParentOme[u] ? j * i : 0;
nd.addJ = nd.addI = 0;
return;
}
int mid = (l + r) >> 1;
build_seg(idx<<1, l, mid);
build_seg(idx<<1|1, mid+1, r);
seg[idx] = mergeNode(seg[idx<<1], seg[idx<<1|1]);
}
// Add (dJ, dI) to subJ / subI of every position in [ql, qr]. Segment `idx`
// covers [l, r]; fully covered segments only receive a lazy tag.
static void range_add(int idx,int l,int r,int ql,int qr,ll dJ,ll dI){
    const bool disjoint = qr < l || r < ql;
    if(disjoint) return;

    const bool covered = ql <= l && r <= qr;
    if(covered){
        apply_add_node(seg[idx], dJ, dI);
        return;
    }

    push_down(idx);
    const int mid = (l + r) >> 1;
    const int left = idx << 1;
    const int right = left | 1;
    if(ql <= mid) range_add(left, l, mid, ql, qr, dJ, dI);
    if(mid < qr) range_add(right, mid + 1, r, ql, qr, dJ, dI);
    seg[idx] = mergeNode(seg[left], seg[right]);
}
static void point_set_ome(int idx,int l,int r,int p,int val){
if(l==r){
seg[idx].cntOme = val;
seg[idx].sumJ_ome = val ? seg[idx].sumJ : 0;
seg[idx].sumI_ome = val ? seg[idx].sumI : 0;
seg[idx].sumJI_ome = val ? seg[idx].sumJI : 0;
return;
}
push_down(idx);
int mid = (l+r)>>1;
if(p<=mid) point_set_ome(idx<<1, l, mid, p, val);
else point_set_ome(idx<<1|1, mid+1, r, p, val);
seg[idx] = mergeNode(seg[idx<<1], seg[idx<<1|1]);
}
static void point_set_parentOme(int idx,int l,int r,int p,int val){
if(l==r){
seg[idx].cntParentOme = val;
seg[idx].sumJ_parent = val ? seg[idx].sumJ : 0;
seg[idx].sumI_parent = val ? seg[idx].sumI : 0;
seg[idx].sumJI_parent = val ? seg[idx].sumJI : 0;
return;
}
push_down(idx);
int mid = (l+r)>>1;
if(p<=mid) point_set_parentOme(idx<<1, l, mid, p, val);
else point_set_parentOme(idx<<1|1, mid+1, r, p, val);
seg[idx] = mergeNode(seg[idx<<1], seg[idx<<1|1]);
}
// getters
static inline ll get_sumJ(){ return seg[1].sumJ; }
static inline ll get_sumI(){ return seg[1].sumI; }
static inline ll get_sumJI(){ return seg[1].sumJI; }
static inline int get_cntOme(){ return seg[1].cntOme; }
static inline ll get_sumJ_ome(){ return seg[1].sumJ_ome; }
static inline ll get_sumI_ome(){ return seg[1].sumI_ome; }
static inline ll get_sumJI_ome(){ return seg[1].sumJI_ome; }
static inline ll get_sumJI_parent(){ return seg[1].sumJI_parent; }
// Add (dJ, dI) to subJ / subI of every node on the path from v up to the root.
// Works one heavy chain at a time: positions pos[headH[cur]]..pos[cur] are
// consecutive, then it jumps to the parent of the chain head (0 above the root).
static inline void path_add_root(int v, ll dJ, ll dI){
    for(int cur = v; cur; ){
        const int top = headH[cur];
        range_add(1, 1, nGlobal, pos[top], pos[cur], dJ, dI);
        cur = parentv[top];
    }
}
// ---------- public API (matches joitour.h) ----------
// Build all structures for a tree of N nodes given with 0-based ids.
//   F[i]        : colour of node i; 0 is counted in Jtot, 2 in Itot, 1 is an omelette
//   U[i], V[i]  : endpoints of edge i, for i in [0, N-1)
//   Q           : not used here
// Internally nodes are shifted to 1-based ids and the tree is rooted at node 1.
void init(int N, std::vector<int> F, std::vector<int> U, std::vector<int> V, int Q){
    (void)Q;
    nGlobal = N;

    // adjacency lists and colours (0-based input -> 1-based ids)
    ecnt = 0;
    std::fill(head + 1, head + N + 1, 0);
    for(int i = 1; i <= N; i++) Fcol[i] = F[i - 1];
    for(int e = 0; e + 1 < N; e++){
        const int a = U[e] + 1;
        const int b = V[e] + 1;
        add_edge(a, b);
        add_edge(b, a);
    }

    // sizes, parents, heavy children, initial subtree counts, then HLD positions
    dfs_sz_iter(1);
    dfs_hld_iter(1);

    // totals and per-node omelette flags in a single pass
    Jtot = 0;
    Itot = 0;
    for(int i = 1; i <= N; i++){
        const int colour = Fcol[i];
        if(colour == 0) ++Jtot;
        else if(colour == 2) ++Itot;
        isOme[i] = (colour == 1);
        const int par = parentv[i];
        isParentOme[i] = (par != 0 && Fcol[par] == 1);
    }

    build_seg(1, 1, N);
}
void change(int X, int Y){
int x = X + 1; // convert
int old = Fcol[x];
if(old == Y) return;
int dJ = 0, dI = 0;
if(old==0) --dJ;
if(old==2) --dI;
if(Y==0) ++dJ;
if(Y==2) ++dI;
if(dJ!=0 || dI!=0){
path_add_root(x, dJ, dI);
Jtot += dJ;
Itot += dI;
}
// update omelette flag and children parent flags
if(old==1 && Y!=1){
isOme[x] = 0;
point_set_ome(1,1,nGlobal,pos[x],0);
for(int e=head[x]; e; e=nxt[e]){
int v = to[e];
if(parentv[v] == x){
isParentOme[v] = 0;
point_set_parentOme(1,1,nGlobal,pos[v],0);
}
}
} else if(old!=1 && Y==1){
isOme[x] = 1;
point_set_ome(1,1,nGlobal,pos[x],1);
for(int e=head[x]; e; e=nxt[e]){
int v = to[e];
if(parentv[v] == x){
isParentOme[v] = 1;
point_set_parentOme(1,1,nGlobal,pos[v],1);
}
}
}
Fcol[x] = Y;
}
long long num_tours(){
// answer = sum_{u:F[u]==1} (Jtot*Itot - S[u])
ll cntO = get_cntOme();
ll term1 = get_sumJI_parent(); // sum of subJ*subI for nodes whose parent is omelette
// term2 = sum_{u in ome} (Jtot - subJ[u])*(Itot - subI[u])
// expand = cntO*Jtot*Itot - Jtot*sumI_ome - Itot*sumJ_ome + sumJI_ome
ll term2 = cntO * (Jtot * Itot) - Jtot * get_sumI_ome() - Itot * get_sumJ_ome() + get_sumJI_ome();
ll sumS_ome = term1 + term2;
ll ans = cntO * (Jtot * Itot) - sumS_ome;
return ans;
}
/*
 * Judge status residue that was pasted after the source; it is not part of
 * the program and is kept only as a comment so the file compiles. The same
 * table header and status line appeared 7 times:
 *   # | Verdict | Execution time | Memory | Grader output |
 *   ---|
 *   Fetching results... |
 */