Submission #1269482

# | Time | Username | Problem | Language | Result | Execution time | Memory
1269482 | (time not captured) | MateiKing80 | JOI tour (JOI24_joitour) | C++20 | 52 / 100 | 3035 ms | 134924 KiB
#include "joitour.h"
#include <bits/stdc++.h>
using namespace std;
using ll = long long;

// Optimized solution using Heavy-Light Decomposition + segment tree
// NOTE: internal node indices are 1..N. The public API (init/change) uses
// the problem's 0-based indices, so we convert inputs to 1-based here.

// N and Q exactly as passed to init(). nGlobal is the segment-tree range bound
// used by change(); qGlobal is stored but not read by the visible code.
static int nGlobal, qGlobal;
// Adjacency lists of the tree, indexed by 1-based node id (filled in init()).
static vector<vector<int>> g;
static vector<int> Fcol; // 1-based: Fcol[1..N] in {0,1,2}; current colour of each node
// Rooted-tree / heavy-light arrays, all indexed by 1-based node id:
//   parentv - parent (0 for the root)        depthv - depth (root = 0)
//   heavy   - child with the largest subtree (-1 for a leaf)
//   head    - top node of the chain containing the node
//   pos     - position of the node in the segment tree (1..N)
//   sz      - subtree size
static vector<int> parentv, depthv, heavy, head, pos, sz;
// Last position handed out by dfs_hld(); reset to 0 in init().
static int curPos;

static vector<int> initJ, initI; // per-subtree counts of colour 0 and colour 2, as computed by dfs_sz()
// Number of nodes in the whole tree that currently have colour 0 / colour 2.
static ll Jtot = 0, Itot = 0;

static vector<int> isOme;       // node is omelette (F==1)
static vector<int> isParentOme; // node is counted in the *_parent aggregates (flag derived from its parent's omelette state)

// Aggregate kept in every segment-tree vertex.
// For the nodes covered by the vertex it stores three parallel groups of
// (count, sum subJ, sum subI, sum subJ*subI):
//   * over all nodes,
//   * over nodes that are omelette,
//   * over nodes flagged as counted children of an omelette node,
// plus the lazy increments still to be pushed to the children.
// Every field starts at zero.
struct Node {
    // all nodes in the segment
    int cnt = 0;
    ll sumJ = 0, sumI = 0, sumJI = 0;
    // omelette nodes only
    int cntOme = 0;
    ll sumJ_ome = 0, sumI_ome = 0, sumJI_ome = 0;
    // nodes flagged through cntParentOme only
    int cntParentOme = 0;
    ll sumJ_parent = 0, sumI_parent = 0, sumJI_parent = 0;
    // pending additions to subJ / subI of every node in the segment
    ll addJ = 0, addI = 0;
};

static vector<Node> seg;

// Roots the subtree of u (whose parent is p, 0 meaning "no parent") and fills
// parentv, depthv, sz, heavy, initJ and initI for every node in it.
//
// Implemented with an explicit stack instead of recursion: the recursion depth
// of the straightforward version equals the height of the tree, which is N on
// a path-shaped input and can exhaust the call stack.
//
// heavy[x] is the first child (in adjacency order) with the largest subtree,
// or -1 for a leaf.
void dfs_sz(int u,int p){
    parentv[u]=p;
    depthv[u]=(p==0?0:depthv[p]+1);

    // Pass 1: visit parents before children, recording the visit order and
    // initialising the per-node values.
    std::vector<int> order;
    order.reserve(g.size());
    std::vector<int> pending(1, u);
    while(!pending.empty()){
        const int x = pending.back();
        pending.pop_back();
        order.push_back(x);
        sz[x]=1;
        heavy[x]=-1;
        initJ[x] = (Fcol[x]==0);
        initI[x] = (Fcol[x]==2);
        for(int v: g[x]) if(v!=parentv[x]){
            parentv[v]=x;
            depthv[v]=depthv[x]+1;
            pending.push_back(v);
        }
    }

    // Pass 2: walk the order backwards so every child is final before it is
    // folded into its parent.
    for(auto it=order.rbegin(); it!=order.rend(); ++it){
        const int x = *it;
        for(int v: g[x]) if(v!=parentv[x]){
            initJ[x] += initJ[v];
            initI[x] += initI[v];
            if(heavy[x]==-1 || sz[v] > sz[heavy[x]]) heavy[x]=v;
            sz[x] += sz[v];
        }
    }
}

// Assigns chain heads and segment-tree positions for the subtree of u, where
// h is the head of the chain u belongs to.
//
// Positions are handed out in the order: node, its whole heavy-child subtree,
// then the light-child subtrees in adjacency order, so every chain occupies a
// contiguous range of positions.  An explicit stack is used because the
// recursive form nests once per tree level, which is N deep on a path.
void dfs_hld(int u,int h){
    std::vector<std::pair<int,int>> pending; // (node, head of its chain)
    pending.emplace_back(u, h);
    while(!pending.empty()){
        const int x   = pending.back().first;
        const int top = pending.back().second;
        pending.pop_back();

        head[x]=top;
        pos[x]=++curPos;

        // Light children start their own chains.  They are pushed in reverse
        // adjacency order so that they are popped in adjacency order.
        const std::vector<int> &adj = g[x];
        for(auto it=adj.rbegin(); it!=adj.rend(); ++it){
            const int v = *it;
            if(v!=parentv[x] && v!=heavy[x]) pending.emplace_back(v, v);
        }
        // The heavy child goes on last so it is processed first and extends
        // the current chain.
        if(heavy[x]!=-1) pending.emplace_back(heavy[x], top);
    }
}

// Returns the aggregate of two sibling segments: every count and sum is the
// field-wise sum of the inputs.  The lazy fields of the result keep the
// zero value a freshly constructed Node has.
Node mergeNode(const Node &a, const Node &b){
    Node out;

    // counts
    out.cnt          = a.cnt          + b.cnt;
    out.cntOme       = a.cntOme       + b.cntOme;
    out.cntParentOme = a.cntParentOme + b.cntParentOme;

    // sums over all nodes
    out.sumJ  = a.sumJ  + b.sumJ;
    out.sumI  = a.sumI  + b.sumI;
    out.sumJI = a.sumJI + b.sumJI;

    // sums over omelette nodes
    out.sumJ_ome  = a.sumJ_ome  + b.sumJ_ome;
    out.sumI_ome  = a.sumI_ome  + b.sumI_ome;
    out.sumJI_ome = a.sumJI_ome + b.sumJI_ome;

    // sums over parent-flagged nodes
    out.sumJ_parent  = a.sumJ_parent  + b.sumJ_parent;
    out.sumI_parent  = a.sumI_parent  + b.sumI_parent;
    out.sumJI_parent = a.sumJI_parent + b.sumJI_parent;

    return out;
}

// Applies "subJ += dJ, subI += dI for every node in the segment" to the
// aggregate nd and records the increment as pending for its children.
// A (0,0) increment is a no-op.
void apply_add(Node &nd, ll dJ, ll dI){
    if(dJ==0 && dI==0) return;

    // For a group of k nodes:
    //   sum (J+dJ)(I+dI) = sumJI + dJ*sumI + dI*sumJ + dJ*dI*k
    // The product sum is updated first because it needs the old sumJ/sumI.
    const ll cross = dJ * dI;
    auto shift = [&](ll &sJ, ll &sI, ll &sJI, ll k){
        sJI += dJ * sI + dI * sJ + cross * k;
        sJ  += dJ * k;
        sI  += dI * k;
    };

    shift(nd.sumJ,        nd.sumI,        nd.sumJI,        nd.cnt);          // all nodes
    shift(nd.sumJ_ome,    nd.sumI_ome,    nd.sumJI_ome,    nd.cntOme);       // omelette nodes
    shift(nd.sumJ_parent, nd.sumI_parent, nd.sumJI_parent, nd.cntParentOme); // parent-flagged nodes

    nd.addJ += dJ;
    nd.addI += dI;
}

void push(int idx){
    ll dJ = seg[idx].addJ, dI = seg[idx].addI;
    if(dJ!=0 || dI!=0){
        apply_add(seg[idx<<1], dJ, dI);
        apply_add(seg[idx<<1|1], dJ, dI);
        seg[idx].addJ = seg[idx].addI = 0;
    }
}

void build(int idx,int l,int r, const vector<int> &baseNode){
    if(l==r){
        int u = baseNode[l];
        seg[idx].cnt = 1;
        ll j = initJ[u];
        ll i = initI[u];
        seg[idx].sumJ = j;
        seg[idx].sumI = i;
        seg[idx].sumJI = j * i;
        seg[idx].cntOme = isOme[u];
        seg[idx].sumJ_ome = isOme[u] ? j : 0;
        seg[idx].sumI_ome = isOme[u] ? i : 0;
        seg[idx].sumJI_ome = isOme[u] ? j*i : 0;
        seg[idx].cntParentOme = isParentOme[u];
        seg[idx].sumJ_parent = isParentOme[u] ? j : 0;
        seg[idx].sumI_parent = isParentOme[u] ? i : 0;
        seg[idx].sumJI_parent = isParentOme[u] ? j*i : 0;
        seg[idx].addJ = seg[idx].addI = 0;
    } else {
        int mid=(l+r)>>1;
        build(idx<<1,l,mid,baseNode);
        build(idx<<1|1,mid+1,r,baseNode);
        seg[idx] = mergeNode(seg[idx<<1], seg[idx<<1|1]);
    }
}

// Adds (dJ, dI) to subJ/subI of every position in [ql, qr].
// idx is the current vertex and [l, r] the range it covers.
void range_add(int idx,int l,int r,int ql,int qr,ll dJ,ll dI){
    const bool disjoint = (r < ql) || (qr < l);
    if(disjoint) return;

    const bool covered = (ql <= l) && (r <= qr);
    if(covered){
        // whole vertex inside the query: update its aggregate, defer the rest
        apply_add(seg[idx], dJ, dI);
        return;
    }

    // partial overlap: settle pending work, recurse, rebuild the aggregate
    push(idx);
    const int mid = (l+r)/2;
    const int lc = 2*idx, rc = 2*idx+1;
    range_add(lc, l,     mid, ql, qr, dJ, dI);
    range_add(rc, mid+1, r,   ql, qr, dJ, dI);
    seg[idx] = mergeNode(seg[lc], seg[rc]);
}

void point_set_ome(int idx,int l,int r,int p,int val){
    if(l==r){
        seg[idx].cntOme = val;
        seg[idx].sumJ_ome = val ? seg[idx].sumJ : 0;
        seg[idx].sumI_ome = val ? seg[idx].sumI : 0;
        seg[idx].sumJI_ome = val ? seg[idx].sumJI : 0;
        return;
    }
    push(idx);
    int mid=(l+r)>>1;
    if(p<=mid) point_set_ome(idx<<1,l,mid,p,val);
    else point_set_ome(idx<<1|1,mid+1,r,p,val);
    seg[idx] = mergeNode(seg[idx<<1], seg[idx<<1|1]);
}

void point_set_parentOme(int idx,int l,int r,int p,int val){
    if(l==r){
        seg[idx].cntParentOme = val;
        seg[idx].sumJ_parent = val ? seg[idx].sumJ : 0;
        seg[idx].sumI_parent = val ? seg[idx].sumI : 0;
        seg[idx].sumJI_parent = val ? seg[idx].sumJI : 0;
        return;
    }
    push(idx);
    int mid=(l+r)>>1;
    if(p<=mid) point_set_parentOme(idx<<1,l,mid,p,val);
    else point_set_parentOme(idx<<1|1,mid+1,r,p,val);
    seg[idx] = mergeNode(seg[idx<<1], seg[idx<<1|1]);
}

inline ll get_sumJ(){ return seg[1].sumJ; }
inline ll get_sumI(){ return seg[1].sumI; }
inline ll get_sumJI(){ return seg[1].sumJI; }
inline int get_cntOme(){ return seg[1].cntOme; }
inline ll get_sumJ_ome(){ return seg[1].sumJ_ome; }
inline ll get_sumI_ome(){ return seg[1].sumI_ome; }
inline ll get_sumJI_ome(){ return seg[1].sumJI_ome; }
inline ll get_sumJI_parent(){ return seg[1].sumJI_parent; }

// add (dJ,dI) to all nodes on path root(=1) -> v
void path_add_root(int v, ll dJ, ll dI, int n){
    while(v!=0){
        int h = head[v];
        range_add(1,1,n,pos[h], pos[v], dJ, dI);
        v = parentv[h];
    }
}

void init(int N, std::vector<int> F, std::vector<int> U, std::vector<int> V, int Q){
    nGlobal = N; qGlobal = Q;
    g.assign(N+1, {});
    Fcol.assign(N+1,0);
    // map F (0-based vector) into Fcol (1-based)
    for(int i=0;i<N;i++) Fcol[i+1] = F[i];
    // build graph: convert input edges (0-based) to 1-based
    for(int i=0;i<N-1;i++){
        int u = U[i] + 1;
        int v = V[i] + 1;
        g[u].push_back(v);
        g[v].push_back(u);
    }
    parentv.assign(N+1,0); depthv.assign(N+1,0); heavy.assign(N+1,-1); head.assign(N+1,0); pos.assign(N+1,0); sz.assign(N+1,0);
    initJ.assign(N+1,0); initI.assign(N+1,0);
    curPos = 0;
    // root the tree at 1
    dfs_sz(1,0);
    dfs_hld(1,1);
    // totals
    Jtot = 0; Itot = 0;
    for(int i=1;i<=N;i++){
        if(Fcol[i]==0) Jtot++;
        else if(Fcol[i]==2) Itot++;
    }
    // prepare omelette and parent-ome flags
    isOme.assign(N+1,0);
    isParentOme.assign(N+1,0);
    for(int i=1;i<=N;i++) isOme[i] = (Fcol[i]==1);
    for(int i=1;i<=N;i++){
        if(parentv[i]!=0) isParentOme[i] = (Fcol[parentv[i]]==1);
    }
    // baseNode pos->node
    vector<int> baseNode(N+1);
    for(int i=1;i<=N;i++) baseNode[pos[i]] = i;
    // build segment tree
    seg.assign(4*(N+5), Node());
    build(1,1,N,baseNode);
}

void change(int X, int Y){
    // X given 0-based externally -> convert to 1-based
    int x = X + 1;
    int old = Fcol[x];
    if(old == Y) return;
    int dJ = 0, dI = 0;
    if(old==0) dJ -= 1;
    if(old==2) dI -= 1;
    if(Y==0) dJ += 1;
    if(Y==2) dI += 1;
    // update subtree counts: all ancestors (incl. x) change by dJ,dI
    if(dJ!=0 || dI!=0){
        path_add_root(x, dJ, dI, nGlobal);
        Jtot += dJ;
        Itot += dI;
    }
    // update omelette flag at x and the parent flags of its children
    if(old==1 && Y!=1){
        // turning off omelette
        isOme[x]=0;
        point_set_ome(1,1,nGlobal,pos[x],0);
        for(int v: g[x]) if(parentv[v]==x){
            isParentOme[v]=0;
            point_set_parentOme(1,1,nGlobal,pos[v],0);
        }
    } else if(old!=1 && Y==1){
        // turning on omelette
        isOme[x]=1;
        point_set_ome(1,1,nGlobal,pos[x],1);
        for(int v: g[x]) if(parentv[v]==x){
            isParentOme[v]=1;
            point_set_parentOme(1,1,nGlobal,pos[v],1);
        }
    }
    Fcol[x] = Y;
}

long long num_tours(){
    // answer = sum_{u:F[u]==1} (Jtot*Itot - S[u])
    // compute sumS_ome and then answer
    ll cntO = get_cntOme();
    ll term1 = get_sumJI_parent(); // sum over nodes whose parent is omelette: subJ*subI
    // term2 = sum_{u in ome} (Jtot - subJ[u])*(Itot - subI[u])
    // expand: cntO*Jtot*Itot - Jtot*sumI_ome - Itot*sumJ_ome + sumJI_ome
    ll term2 = cntO * (Jtot * Itot) - Jtot * get_sumI_ome() - Itot * get_sumJ_ome() + get_sumJI_ome();
    ll sumS_ome = term1 + term2;
    ll ans = cntO * (Jtot * Itot) - sumS_ome;
    return ans;
}
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...