Submission #1269488

# | Time | Username | Problem | Language | Result | Execution time | Memory
1269488 | | MateiKing80 | JOI tour (JOI24_joitour) | C++20 | 52 / 100 | 3034 ms | 77520 KiB
#include "joitour.h"
#include <bits/stdc++.h>
using namespace std;
using ll = long long;

// Capacities. Node ids are 1-based; the +5 leaves slack past the last id.
constexpr int MAXN = 200000 + 5;
constexpr int MAXE = 2 * MAXN;      // every undirected edge is stored twice
constexpr int SEG_SIZE = 4 * MAXN;  // heap-indexed segment tree over HLD positions

// Adjacency lists as singly linked edge records: head[u] is u's first edge id
// (0 = no more edges), to[e] is the edge's endpoint, nxt[e] the next edge of u.
static int head[MAXN];
static int to[MAXE];
static int nxt[MAXE];
static int ecnt;                    // number of edge records in use

// Rooted-tree and heavy-light decomposition data (1-based node ids).
static int nGlobal;                 // number of nodes N
static int parentv[MAXN];           // parent of each node, 0 for the root
static int depthv[MAXN];            // distance from the root
static int heavy[MAXN];             // child with the largest subtree, -1 for leaves
static int headH[MAXN];             // top node of the heavy chain containing the node
static int pos[MAXN];               // 1-based position in HLD order
static int szArr[MAXN];             // subtree size
static int curPos;                  // last position handed out by the HLD pass
static int baseNode[MAXN];          // inverse of pos: position -> node

// Colors (0, 1 or 2) and init-time subtree counts of color 0 / color 2.
static int Fcol[MAXN];
static int initJ[MAXN];
static int initI[MAXN];
static ll Jtot, Itot;               // current global counts of color 0 / color 2

// Flags mirrored into the segment tree: isOme[u] = u has color 1;
// isParentOme[v] = v's subtree counts feed the *_parent aggregates
// (only ever set when v's parent has color 1).
static int isOme[MAXN];
static int isParentOme[MAXN];

// Segment-tree node over a range of HLD positions. For the tree node u stored
// at a position, J = number of color-0 nodes in u's subtree and I = number of
// color-2 nodes in it. Each family keeps (count, sum J, sum I, sum J*I):
//   plain     - over every position in the range,
//   *_ome     - over positions whose isOme flag is set,
//   *_parent  - over positions whose isParentOme flag is set.
struct Node {
    int cnt;            // positions covered
    ll sumJ;
    ll sumI;
    ll sumJI;

    int cntOme;         // positions with the omelette flag
    ll sumJ_ome;
    ll sumI_ome;
    ll sumJI_ome;

    int cntParentOme;   // positions with the parent-omelette flag
    ll sumJ_parent;
    ll sumI_parent;
    ll sumJI_parent;

    ll addJ;            // pending (lazy) addition to J, not yet pushed to children
    ll addI;            // pending (lazy) addition to I, not yet pushed to children
};

// Heap layout: root at index 1, children of i at 2i and 2i+1.
static Node seg[SEG_SIZE];

// ---------- adjacency helpers ----------
// Prepend the directed edge u -> v to u's edge list. Edge ids start at 1 so
// that 0 can terminate a list.
static inline void add_edge(int u,int v){
    const int id = ++ecnt;
    to[id] = v;
    nxt[id] = head[u];
    head[u] = id;
}

// ---------- iterative dfs to compute sz, parent, depth, initJ/I, heavy ----------
static void dfs_sz_iter(int root){
    // stack: pair(node, state). state 0 = enter, 1 = exit
    static int st_node[MAXN];
    static char st_state[MAXN];
    int sp = 0;
    st_node[sp] = root; st_state[sp] = 0; sp++;
    parentv[root] = 0;
    depthv[root] = 0;
    while(sp){
        --sp;
        int u = st_node[sp];
        char state = st_state[sp];
        if(state == 0){
            // enter
            st_node[sp] = u; st_state[sp] = 1; sp++;
            // push children
            for(int e = head[u]; e; e = nxt[e]){
                int v = to[e];
                if(v == parentv[u]) continue;
                parentv[v] = u;
                depthv[v] = depthv[u] + 1;
                st_node[sp] = v; st_state[sp] = 0; sp++;
            }
        } else {
            // exit: compute sz, aggregate initJ/initI
            szArr[u] = 1;
            initJ[u] = (Fcol[u] == 0) ? 1 : 0;
            initI[u] = (Fcol[u] == 2) ? 1 : 0;
            int best = -1;
            int bestsz = 0;
            for(int e = head[u]; e; e = nxt[e]){
                int v = to[e];
                if(v == parentv[u]) continue;
                szArr[u] += szArr[v];
                initJ[u] += initJ[v];
                initI[u] += initI[v];
                if(szArr[v] > bestsz){ bestsz = szArr[v]; best = v; }
            }
            heavy[u] = best;
        }
    }
}

// ---------- iterative HLD decomposition ----------
static void dfs_hld_iter(int root){
    curPos = 0;
    // stack of chain heads to start
    static int st[MAXN];
    int sp = 0;
    st[sp++] = root;
    while(sp){
        int u = st[--sp];
        int h = u;
        // go down heavy chain
        for(int v = u; v != -1; v = heavy[v]){
            headH[v] = h;
            pos[v] = ++curPos;
            baseNode[curPos] = v;
            // push light children to stack for later processing
            for(int e = head[v]; e; e = nxt[e]){
                int c = to[e];
                if(c == parentv[v] || c == heavy[v]) continue;
                st[sp++] = c;
            }
        }
    }
}

// ---------- segment tree helpers ----------
// Combine two sibling aggregates. Every count and sum is additive. The result
// carries no pending lazy addition; the children keep their own.
static inline Node mergeNode(const Node &a, const Node &b){
    Node res = a;
    res.cnt += b.cnt;
    res.sumJ += b.sumJ;
    res.sumI += b.sumI;
    res.sumJI += b.sumJI;

    res.cntOme += b.cntOme;
    res.sumJ_ome += b.sumJ_ome;
    res.sumI_ome += b.sumI_ome;
    res.sumJI_ome += b.sumJI_ome;

    res.cntParentOme += b.cntParentOme;
    res.sumJ_parent += b.sumJ_parent;
    res.sumI_parent += b.sumI_parent;
    res.sumJI_parent += b.sumJI_parent;

    res.addJ = 0;
    res.addI = 0;
    return res;
}

// Add dJ to every J and dI to every I inside the node's range and remember the
// addition as a pending lazy tag. Each family is updated in O(1) from
//   sum (J+dJ)(I+dI) = sumJI + dJ*sumI + dI*sumJ + dJ*dI*count,
// so sumJI must be updated before sumJ / sumI change.
static inline void apply_add_node(Node &nd, ll dJ, ll dI){
    if (dJ == 0 && dI == 0) return;
    auto shift = [dJ, dI](ll count, ll &sJ, ll &sI, ll &sJI) {
        sJI += dJ * sI + dI * sJ + dJ * dI * count;
        sJ += dJ * count;
        sI += dI * count;
    };
    shift(nd.cnt, nd.sumJ, nd.sumI, nd.sumJI);
    shift(nd.cntOme, nd.sumJ_ome, nd.sumI_ome, nd.sumJI_ome);
    shift(nd.cntParentOme, nd.sumJ_parent, nd.sumI_parent, nd.sumJI_parent);
    nd.addJ += dJ;
    nd.addI += dI;
}

static inline void push_down(int idx){
    ll dJ = seg[idx].addJ;
    ll dI = seg[idx].addI;
    if(dJ!=0 || dI!=0){
        apply_add_node(seg[idx<<1], dJ, dI);
        apply_add_node(seg[idx<<1|1], dJ, dI);
        seg[idx].addJ = seg[idx].addI = 0;
    }
}

static void build_seg(int idx,int l,int r){
    if(l==r){
        int u = baseNode[l];
        Node &nd = seg[idx];
        nd.cnt = 1;
        ll j = initJ[u];
        ll i = initI[u];
        nd.sumJ = j;
        nd.sumI = i;
        nd.sumJI = j * i;
        nd.cntOme = isOme[u];
        nd.sumJ_ome = isOme[u] ? j : 0;
        nd.sumI_ome = isOme[u] ? i : 0;
        nd.sumJI_ome = isOme[u] ? j * i : 0;
        nd.cntParentOme = isParentOme[u];
        nd.sumJ_parent = isParentOme[u] ? j : 0;
        nd.sumI_parent = isParentOme[u] ? i : 0;
        nd.sumJI_parent = isParentOme[u] ? j * i : 0;
        nd.addJ = nd.addI = 0;
        return;
    }
    int mid = (l + r) >> 1;
    build_seg(idx<<1, l, mid);
    build_seg(idx<<1|1, mid+1, r);
    seg[idx] = mergeNode(seg[idx<<1], seg[idx<<1|1]);
}

// Add (dJ, dI) to every position in [ql, qr], starting at node idx that
// covers [l, r]. Fully covered nodes only receive a lazy tag.
static void range_add(int idx,int l,int r,int ql,int qr,ll dJ,ll dI){
    const bool disjoint = (qr < l) || (r < ql);
    if (disjoint) return;
    const bool covered = (ql <= l) && (r <= qr);
    if (covered) {
        apply_add_node(seg[idx], dJ, dI);
        return;
    }
    push_down(idx);
    const int mid = l + (r - l) / 2;
    const int lc = 2 * idx;
    const int rc = lc + 1;
    if (ql <= mid) range_add(lc, l, mid, ql, qr, dJ, dI);
    if (mid < qr) range_add(rc, mid + 1, r, ql, qr, dJ, dI);
    seg[idx] = mergeNode(seg[lc], seg[rc]);
}

static void point_set_ome(int idx,int l,int r,int p,int val){
    if(l==r){
        seg[idx].cntOme = val;
        seg[idx].sumJ_ome = val ? seg[idx].sumJ : 0;
        seg[idx].sumI_ome = val ? seg[idx].sumI : 0;
        seg[idx].sumJI_ome = val ? seg[idx].sumJI : 0;
        return;
    }
    push_down(idx);
    int mid = (l+r)>>1;
    if(p<=mid) point_set_ome(idx<<1, l, mid, p, val);
    else point_set_ome(idx<<1|1, mid+1, r, p, val);
    seg[idx] = mergeNode(seg[idx<<1], seg[idx<<1|1]);
}

static void point_set_parentOme(int idx,int l,int r,int p,int val){
    if(l==r){
        seg[idx].cntParentOme = val;
        seg[idx].sumJ_parent = val ? seg[idx].sumJ : 0;
        seg[idx].sumI_parent = val ? seg[idx].sumI : 0;
        seg[idx].sumJI_parent = val ? seg[idx].sumJI : 0;
        return;
    }
    push_down(idx);
    int mid = (l+r)>>1;
    if(p<=mid) point_set_parentOme(idx<<1, l, mid, p, val);
    else point_set_parentOme(idx<<1|1, mid+1, r, p, val);
    seg[idx] = mergeNode(seg[idx<<1], seg[idx<<1|1]);
}

// getters
static inline ll get_sumJ(){ return seg[1].sumJ; }
static inline ll get_sumI(){ return seg[1].sumI; }
static inline ll get_sumJI(){ return seg[1].sumJI; }
static inline int get_cntOme(){ return seg[1].cntOme; }
static inline ll get_sumJ_ome(){ return seg[1].sumJ_ome; }
static inline ll get_sumI_ome(){ return seg[1].sumI_ome; }
static inline ll get_sumJI_ome(){ return seg[1].sumJI_ome; }
static inline ll get_sumJI_parent(){ return seg[1].sumJI_parent; }

// add (dJ,dI) on path root->v
static inline void path_add_root(int v, ll dJ, ll dI){
    while(v != 0){
        int h = headH[v];
        range_add(1, 1, nGlobal, pos[h], pos[v], dJ, dI);
        v = parentv[h];
    }
}

// ---------- public API (matches joitour.h) ----------
void init(int N, std::vector<int> F, std::vector<int> U, std::vector<int> V, int Q){
    nGlobal = N;
    // reset adjacency
    ecnt = 0;
    for(int i=1;i<=nGlobal;i++) head[i] = 0;
    // fill Fcol (1-based)
    for(int i=0;i<nGlobal;i++) Fcol[i+1] = F[i];
    // build edges (input is 0-based -> convert to 1-based)
    for(int i=0;i<nGlobal-1;i++){
        int u = U[i] + 1;
        int v = V[i] + 1;
        add_edge(u,v);
        add_edge(v,u);
    }
    // compute sizes, heavy, parent, depth, initJ/initI
    dfs_sz_iter(1);
    // HLD positions
    dfs_hld_iter(1);
    // totals
    Jtot = 0; Itot = 0;
    for(int i=1;i<=nGlobal;i++){
        if(Fcol[i]==0) ++Jtot;
        else if(Fcol[i]==2) ++Itot;
    }
    // omelette flags
    for(int i=1;i<=nGlobal;i++) isOme[i] = (Fcol[i]==1);
    for(int i=1;i<=nGlobal;i++){
        int p = parentv[i];
        isParentOme[i] = (p!=0 && Fcol[p]==1);
    }
    // build segment tree base
    build_seg(1,1,nGlobal);
}

void change(int X, int Y){
    int x = X + 1; // convert
    int old = Fcol[x];
    if(old == Y) return;
    int dJ = 0, dI = 0;
    if(old==0) --dJ;
    if(old==2) --dI;
    if(Y==0) ++dJ;
    if(Y==2) ++dI;
    if(dJ!=0 || dI!=0){
        path_add_root(x, dJ, dI);
        Jtot += dJ;
        Itot += dI;
    }
    // update omelette flag and children parent flags
    if(old==1 && Y!=1){
        isOme[x] = 0;
        point_set_ome(1,1,nGlobal,pos[x],0);
        for(int e=head[x]; e; e=nxt[e]){
            int v = to[e];
            if(parentv[v] == x){
                isParentOme[v] = 0;
                point_set_parentOme(1,1,nGlobal,pos[v],0);
            }
        }
    } else if(old!=1 && Y==1){
        isOme[x] = 1;
        point_set_ome(1,1,nGlobal,pos[x],1);
        for(int e=head[x]; e; e=nxt[e]){
            int v = to[e];
            if(parentv[v] == x){
                isParentOme[v] = 1;
                point_set_parentOme(1,1,nGlobal,pos[v],1);
            }
        }
    }
    Fcol[x] = Y;
}

long long num_tours(){
    // answer = sum_{u:F[u]==1} (Jtot*Itot - S[u])
    ll cntO = get_cntOme();
    ll term1 = get_sumJI_parent(); // sum of subJ*subI for nodes whose parent is omelette
    // term2 = sum_{u in ome} (Jtot - subJ[u])*(Itot - subI[u])
    // expand = cntO*Jtot*Itot - Jtot*sumI_ome - Itot*sumJ_ome + sumJI_ome
    ll term2 = cntO * (Jtot * Itot) - Jtot * get_sumI_ome() - Itot * get_sumJ_ome() + get_sumJI_ome();
    ll sumS_ome = term1 + term2;
    ll ans = cntO * (Jtot * Itot) - sumS_ome;
    return ans;
}
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...