Submission #1269480

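| # | Time | Username | Problem | Language | Result | Execution time | Memory |
|---|------|----------|---------|----------|--------|----------------|--------|
| 1269480 | | MateiKing80 | JOI tour (JOI24_joitour) | C++20 | 0 / 100 | 0 ms | 408 KiB |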
#include "joitour.h"
#include <bits/stdc++.h>
using namespace std;
using ll = long long;

// Optimized solution using Heavy-Light Decomposition + segment tree with lazy
// Maintains for each node u: subJ[u], subI[u] = counts of color 0 and 2 in subtree u.
// We support path updates (root->x) adding +/-1 to subJ/subI when a node's color changes.
// Segment tree over HLD base-array stores aggregated sums needed to compute the global answer:
//   - sums of subJ, subI, subJ*subI over arbitrary ranges
//   - sums restricted to nodes that are omelette (F==1)
//   - sums restricted to nodes whose parent is omelette (used for children contributions)
// Using these we can compute total T = sum_{u: F[u]==1} S[u], and answer = cntOme * Jtot * Itot - T

// Sizes recorded by init(). nGlobal is the upper bound of the segment-tree range
// [1, nGlobal]; qGlobal is stored but not read in the visible code.
static int nGlobal, qGlobal;
// Adjacency lists indexed by internal vertex id.
static vector<vector<int>> g;
static vector<int> Fcol; // current colour of each vertex: 0/1/2
// HLD data filled by dfs_sz / dfs_hld: parent (0 = none), depth, heavy child
// (-1 = none), chain head, position in the base array (1-based), subtree size.
static vector<int> parentv, depthv, heavy, head, pos, sz;
static int curPos; // last base-array position handed out by dfs_hld

// Initial subtree counts
static vector<int> initJ, initI; // counts of color 0 and 2 in subtree, as computed at init time
static ll Jtot = 0, Itot = 0; // current global number of colour-0 / colour-2 vertices

// For parent-flag and omelette flag maintenance
static vector<int> isOme; // 1 if F==1
static vector<int> isParentOme; // per-vertex flag mirrored into the segment tree's cntParentOme field

// Segment tree node. Each aggregate exists in three flavours:
//   * over every position of the range,
//   * over omelette positions only               (*_ome),
//   * over positions flagged in cntParentOme     (*_parent).
// Per flavour we keep a count plus the sums of subJ, subI and subJ*subI, which is
// exactly what is needed to apply "+dJ to every subJ, +dI to every subI" lazily.
struct Node {
    int cnt = 0;
    ll sumJ = 0, sumI = 0, sumJI = 0; // sum of subJ, subI, subJ*subI

    int cntOme = 0;
    ll sumJ_ome = 0, sumI_ome = 0, sumJI_ome = 0;

    int cntParentOme = 0;
    ll sumJ_parent = 0, sumI_parent = 0, sumJI_parent = 0;

    ll addJ = 0, addI = 0; // pending lazy additions not yet pushed to the children
};

// Segment tree storage: root at index 1, children of i at 2i and 2i+1.
static vector<Node> seg;

void dfs_sz(int u,int p){
    parentv[u]=p; depthv[u]=(p==0?0:depthv[p]+1);
    sz[u]=1; heavy[u]=-1;
    initJ[u] = (Fcol[u]==0);
    initI[u] = (Fcol[u]==2);
    for(int v: g[u]) if(v!=p){
        dfs_sz(v,u);
        initJ[u] += initJ[v];
        initI[u] += initI[v];
        if(heavy[u]==-1 || sz[v] > sz[heavy[u]]) heavy[u]=v;
        sz[u] += sz[v];
    }
}

void dfs_hld(int u,int h){
    head[u]=h; pos[u]=++curPos;
    if(heavy[u]!=-1) dfs_hld(heavy[u], h);
    for(int v: g[u]) if(v!=parentv[u] && v!=heavy[u]) dfs_hld(v, v);
}

// Combine two sibling nodes. Every aggregate is additive; the lazy fields of the
// returned node stay at their default value of zero.
Node mergeNode(const Node &a, const Node &b){
    Node res;
    // counts
    res.cnt          = a.cnt          + b.cnt;
    res.cntOme       = a.cntOme       + b.cntOme;
    res.cntParentOme = a.cntParentOme + b.cntParentOme;
    // all positions
    res.sumJ  = a.sumJ  + b.sumJ;
    res.sumI  = a.sumI  + b.sumI;
    res.sumJI = a.sumJI + b.sumJI;
    // omelette positions
    res.sumJ_ome  = a.sumJ_ome  + b.sumJ_ome;
    res.sumI_ome  = a.sumI_ome  + b.sumI_ome;
    res.sumJI_ome = a.sumJI_ome + b.sumJI_ome;
    // parent-flagged positions
    res.sumJ_parent  = a.sumJ_parent  + b.sumJ_parent;
    res.sumI_parent  = a.sumI_parent  + b.sumI_parent;
    res.sumJI_parent = a.sumJI_parent + b.sumJI_parent;
    return res;
}

// Add dJ to subJ and dI to subI at every position covered by nd, updating all
// three aggregate flavours, and record the addition as pending for nd's children.
void apply_add(Node &nd, ll dJ, ll dI){
    if(!dJ && !dI) return;
    // sum((J+dJ)*(I+dI)) = sum(J*I) + dJ*sum(I) + dI*sum(J) + dJ*dI*count.
    // The product sum reads the OLD linear sums, so it is updated first.
    auto shift = [dJ, dI](ll count, ll &sJ, ll &sI, ll &sJI){
        sJI += dJ * sI + dI * sJ + dJ * dI * count;
        sJ  += dJ * count;
        sI  += dI * count;
    };
    shift(nd.cnt,          nd.sumJ,        nd.sumI,        nd.sumJI);
    shift(nd.cntOme,       nd.sumJ_ome,    nd.sumI_ome,    nd.sumJI_ome);
    shift(nd.cntParentOme, nd.sumJ_parent, nd.sumI_parent, nd.sumJI_parent);
    nd.addJ += dJ;
    nd.addI += dI;
}

void push(int idx){
    ll dJ = seg[idx].addJ, dI = seg[idx].addI;
    if(dJ!=0 || dI!=0){
        apply_add(seg[idx<<1], dJ, dI);
        apply_add(seg[idx<<1|1], dJ, dI);
        seg[idx].addJ = seg[idx].addI = 0;
    }
}

void build(int idx,int l,int r, const vector<int> &baseNode){
    if(l==r){
        int u = baseNode[l];
        seg[idx].cnt = 1;
        ll j = initJ[u];
        ll i = initI[u];
        seg[idx].sumJ = j;
        seg[idx].sumI = i;
        seg[idx].sumJI = j * i;
        seg[idx].cntOme = isOme[u];
        seg[idx].sumJ_ome = isOme[u] ? j : 0;
        seg[idx].sumI_ome = isOme[u] ? i : 0;
        seg[idx].sumJI_ome = isOme[u] ? j*i : 0;
        seg[idx].cntParentOme = isParentOme[u];
        seg[idx].sumJ_parent = isParentOme[u] ? j : 0;
        seg[idx].sumI_parent = isParentOme[u] ? i : 0;
        seg[idx].sumJI_parent = isParentOme[u] ? j*i : 0;
        seg[idx].addJ = seg[idx].addI = 0;
    } else {
        int mid=(l+r)>>1;
        build(idx<<1,l,mid,baseNode);
        build(idx<<1|1,mid+1,r,baseNode);
        seg[idx] = mergeNode(seg[idx<<1], seg[idx<<1|1]);
    }
}

// Add (dJ, dI) to subJ/subI at every base position in [ql, qr]; node idx covers [l, r].
void range_add(int idx,int l,int r,int ql,int qr,ll dJ,ll dI){
    const bool disjoint = qr < l || r < ql;
    if(disjoint) return;
    const bool covered = ql <= l && r <= qr;
    if(covered){
        apply_add(seg[idx], dJ, dI); // stays lazy below this node
        return;
    }
    push(idx);
    const int mid = (l + r) / 2;
    range_add(2*idx, l, mid, ql, qr, dJ, dI);
    range_add(2*idx+1, mid+1, r, ql, qr, dJ, dI);
    seg[idx] = mergeNode(seg[2*idx], seg[2*idx+1]);
}

// point update for toggling omelette flag at position p (set to val 0/1)
void point_set_ome(int idx,int l,int r,int p,int val){
    if(l==r){
        seg[idx].cntOme = val;
        // sumJ_ome etc reflect current subJ (which include lazy adds)
        seg[idx].sumJ_ome = val ? seg[idx].sumJ : 0;
        seg[idx].sumI_ome = val ? seg[idx].sumI : 0;
        seg[idx].sumJI_ome = val ? seg[idx].sumJI : 0;
        return;
    }
    push(idx);
    int mid=(l+r)>>1;
    if(p<=mid) point_set_ome(idx<<1,l,mid,p,val);
    else point_set_ome(idx<<1|1,mid+1,r,p,val);
    seg[idx] = mergeNode(seg[idx<<1], seg[idx<<1|1]);
}

// point update for toggling parent-ome flag at position p (set to val 0/1)
void point_set_parentOme(int idx,int l,int r,int p,int val){
    if(l==r){
        seg[idx].cntParentOme = val;
        seg[idx].sumJ_parent = val ? seg[idx].sumJ : 0;
        seg[idx].sumI_parent = val ? seg[idx].sumI : 0;
        seg[idx].sumJI_parent = val ? seg[idx].sumJI : 0;
        return;
    }
    push(idx);
    int mid=(l+r)>>1;
    if(p<=mid) point_set_parentOme(idx<<1,l,mid,p,val);
    else point_set_parentOme(idx<<1|1,mid+1,r,p,val);
    seg[idx] = mergeNode(seg[idx<<1], seg[idx<<1|1]);
}

// Query whole tree aggregates
inline ll get_sumJ(){ return seg[1].sumJ; }
inline ll get_sumI(){ return seg[1].sumI; }
inline ll get_sumJI(){ return seg[1].sumJI; }
inline int get_cntOme(){ return seg[1].cntOme; }
inline ll get_sumJ_ome(){ return seg[1].sumJ_ome; }
inline ll get_sumI_ome(){ return seg[1].sumI_ome; }
inline ll get_sumJI_ome(){ return seg[1].sumJI_ome; }
inline ll get_sumJI_parent(){ return seg[1].sumJI_parent; }

// HLD path update: add dJ,dI to subJ/subI for all nodes on path root(=1) -> v
void path_add_root(int v, ll dJ, ll dI, int n){
    while(v!=0){
        int h = head[v];
        range_add(1,1,n,pos[h], pos[v], dJ, dI);
        v = parentv[h];
    }
}

// Build base node ordering array (pos->node)

// Exposed API

void init(int N, std::vector<int> F, std::vector<int> U, std::vector<int> V, int Q){
    nGlobal = N; qGlobal = Q;
    g.assign(N+1, {});
    Fcol.assign(N+1,0);
    for(int i=1;i<=N;i++) Fcol[i]=F[i-1];
    for(int i=0;i<N-1;i++){
        int u=U[i], v=V[i];
        g[u].push_back(v); g[v].push_back(u);
    }
    parentv.assign(N+1,0); depthv.assign(N+1,0); heavy.assign(N+1,-1); head.assign(N+1,0); pos.assign(N+1,0); sz.assign(N+1,0);
    initJ.assign(N+1,0); initI.assign(N+1,0);
    curPos = 0;
    dfs_sz(1,0);
    dfs_hld(1,1);
    // compute initial totals
    Jtot = 0; Itot = 0;
    for(int i=1;i<=N;i++){
        if(Fcol[i]==0) Jtot++;
        else if(Fcol[i]==2) Itot++;
    }
    // prepare isOme and isParentOme
    isOme.assign(N+1,0);
    isParentOme.assign(N+1,0);
    for(int i=1;i<=N;i++) isOme[i] = (Fcol[i]==1);
    for(int i=1;i<=N;i++){
        if(parentv[i]!=0) isParentOme[i] = (Fcol[parentv[i]]==1);
    }
    // build baseNode: mapping from pos->node
    vector<int> baseNode(N+1);
    for(int i=1;i<=N;i++) baseNode[pos[i]] = i;
    // segment tree
    seg.assign(4*(N+5), Node());
    build(1,1,N,baseNode);
}

void change(int X, int Y){
    int old = Fcol[X];
    if(old == Y) return;
    // update global Jtot/Itot and isOme, isParentOme
    int dJ = 0, dI = 0;
    if(old==0) dJ -= 1;
    if(old==2) dI -= 1;
    if(Y==0) dJ += 1;
    if(Y==2) dI += 1;
    // 1) Before changing subJ/subI values, we will apply path updates to subJ/subI
    // For a change at node X, all ancestors (including X) have their subJ/subI changed by dJ/dI.
    if(dJ!=0 || dI!=0){
        path_add_root(X, dJ, dI, nGlobal);
        Jtot += dJ; Itot += dI;
    }
    // 2) Update omelette flag at X if needed
    if(old==1 && Y!=1){
        // turning off omelette
        isOme[X]=0;
        point_set_ome(1,1,nGlobal,pos[X],0);
        // update children nodes' parent flags (they stop seeing parent as omelette)
        for(int v: g[X]) if(parentv[v]==X){
            isParentOme[v]=0;
            point_set_parentOme(1,1,nGlobal,pos[v],0);
        }
    } else if(old!=1 && Y==1){
        // turning on omelette
        isOme[X]=1;
        point_set_ome(1,1,nGlobal,pos[X],1);
        // update children parent flags
        for(int v: g[X]) if(parentv[v]==X){
            isParentOme[v]=1;
            point_set_parentOme(1,1,nGlobal,pos[v],1);
        }
    }
    // 3) If parent of X changed its omelette status? handled when that parent was changed in its own change call
    // 4) Update Fcol
    Fcol[X] = Y;
}

long long num_tours(){
    // answer = sum_{u:F[u]==1} (Jtot*Itot - S[u])
    // where S[u] = sum_children Pedge[child] + (Jtot - subJ[u])*(Itot - subI[u])
    // sum_{u in ome} S[u] = sum_childrenP_for_ome (which equals seg.sumJI_parent) + sum_{u in ome} (Jtot - subJ[u])*(Itot - subI[u])
    // second term expands: cntOme*Jtot*Itot - Jtot*sumI_ome - Itot*sumJ_ome + sumJI_ome
    ll cntO = get_cntOme();
    ll term1 = get_sumJI_parent();
    ll term2 = cntO * (Jtot * Itot) - Jtot * get_sumI_ome() - Itot * get_sumJ_ome() + get_sumJI_ome();
    ll sumS_ome = term1 + term2;
    ll ans = cntO * (Jtot * Itot) - sumS_ome;
    return ans;
}
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...