Submission #1292619

# | Time | Username | Problem | Language | Result | Execution time | Memory
1292619 | - | Minbaev | Factories (JOI14_factories) | C++20 | Compilation error | 0 ms | 0 KiB
#include "factories.h"
#include <bits/stdc++.h>
using namespace std;

#define pb push_back
#define ar array

// Capacity of every per-node global array (5 * 10^5 nodes plus a little slack).
constexpr int NN = 500'000 + 5;

// Lazy segment tree over positions 1..sz (despite the name, not a Fenwick tree) that
// supports "assign x to every position in [l, r]" and "sum of positions [l, r]".
//   t[v]  - sum over node v's range (exact once pending assignments above v are pushed)
//   bn[v] - pending assignment for node v's whole range, or -1 when there is none.
// Because -1 is the "no pending assignment" sentinel, the value -1 itself cannot be
// assigned.
struct Bit {
    int sz;
    vector<long long> t, bn;
    Bit (int n = 0){
        this->sz = n;
        t.assign(sz * 4 + 5, 0);
        bn.assign(sz * 4 + 5, -1LL);
    }

    // Apply node v's pending assignment (if any) to t[v] and hand it down to the
    // children; leaves (tl == tr) have no children to receive it.
    void push(int tl, int tr, int v){
        if(bn[v] == -1LL) return;
        if(tl != tr){
            bn[v*2] = bn[v];
            bn[v*2+1] = bn[v];
        }
        t[v] = (long long)(tr - tl + 1) * bn[v];
        bn[v] = -1LL;
    }

    // Recursive range assignment on node v covering [tl, tr].
    void update(int tl, int tr, int v, int l, int r, long long val){
        push(tl, tr, v);
        if(r < l || tr < l || r < tl) return;
        if(l <= tl && tr <= r){
            bn[v] = val;
            push(tl, tr, v);
            return;
        }
        int tm = (tl + tr) >> 1;
        update(tl, tm, v*2, l, r, val);
        update(tm+1, tr, v*2+1, l, r, val);
        t[v] = t[v*2] + t[v*2+1];
    }

    // Assign x to every position in [l, r] (1-based, inclusive); empty ranges are a
    // no-op. x is taken as long long so values outside int range are stored exactly;
    // int arguments still convert implicitly.
    void upd(int l, int r, long long x){
        if(l > r) return;
        update(1, sz, 1, l, r, x);
    }

    // Recursive range-sum on node v covering [tl, tr]; children are pushed on the way
    // down, so t[v] is recomputed from them on the way back up.
    long long gt(int tl, int tr, int v, int l, int r){
        push(tl, tr, v);
        if(r < l || tr < l || r < tl) return 0LL;
        if(l <= tl && tr <= r){
            return t[v];
        }
        int tm = (tl + tr) >> 1;
        long long a = gt(tl, tm, v*2, l, r);
        long long b = gt(tm+1, tr, v*2+1, l, r);
        t[v] = t[v*2] + t[v*2+1];
        return a + b;
    }

    // Sum of positions [l, r] (1-based, inclusive); 0 for an empty range.
    long long get(int l, int r){
        if(l > r) return 0LL;
        return gt(1, sz, 1, l, r);
    }
};

// Segment tree over HLD positions; sized properly by Init.
Bit t;

// Adjacency lists: g[v] holds (neighbour, edge length) pairs.
static vector<pair<int,int>> g[NN];
// Heavy-light decomposition bookkeeping (indexed by node, except id_rev).
static vector<int> id(NN);        // node -> HLD position (1-based)
static vector<int> head(NN);      // node -> top node of its heavy chain
static vector<int> id_rev(NN);    // HLD position -> node
static vector<int> parent_v(NN);  // node -> parent (-1 for the root)
static vector<int> sz_v(NN);      // node -> subtree size
// Weighted depth: sum of edge lengths from the root.
static vector<long long> dep(NN);
static int timer_glob = 0;        // last HLD position handed out
static int Nglob = 0;             // node count from the latest Init

// Roots the subtree of x at parent pr (-1 for the tree root) and fills, for every node
// in it, parent_v, sz_v (subtree size) and dep (weighted depth: dep[child] =
// dep[parent] + edge length). dep[x] itself must already be set by the caller.
// Iterative on purpose: a path-shaped tree of ~5*10^5 nodes would overflow the call
// stack with one recursive frame per level.
void dfs(int x, int pr){
    parent_v[x] = pr;
    vector<int> order;   // nodes in discovery order: every node comes after its parent
    vector<int> stk;
    stk.push_back(x);
    while(!stk.empty()){
        const int v = stk.back();
        stk.pop_back();
        order.push_back(v);
        sz_v[v] = 1;
        for(const auto &[to, w] : g[v]){
            if(to == parent_v[v]) continue;
            parent_v[to] = v;
            dep[to] = dep[v] + (long long)w;
            stk.push_back(to);
        }
    }
    // Reverse discovery order finishes every child's subtree size before it is added
    // to its parent; order[0] == x, whose parent lies outside this subtree.
    for(size_t i = order.size(); i-- > 1; ){
        const int v = order[i];
        sz_v[parent_v[v]] += sz_v[v];
    }
}

// Heavy-light decomposition of the subtree of x (whose parent is pr). It hands out
// consecutive positions id[] (continuing from timer_glob) in a DFS preorder that visits
// the heavy child (largest sz_v) first, then the light children in adjacency order.
// It fills id_rev (position -> node) and head[] (top of each heavy chain).
// Preconditions: head[x] is set by the caller; sz_v was filled by dfs.
// Iterative with an explicit stack, producing exactly the numbering of the recursive
// formulation, so deep trees cannot overflow the call stack.
void hld(int x, int pr){
    vector<pair<int,int>> stk;   // (node, its parent) still to be numbered
    stk.emplace_back(x, pr);
    while(!stk.empty()){
        const auto [v, p] = stk.back();
        stk.pop_back();

        ++timer_glob;
        id[v] = timer_glob;
        id_rev[timer_glob] = v;

        int heavy = -1;
        for(const auto &ed : g[v]){
            const int to = ed.first;
            if(to == p) continue;
            if(heavy == -1 || sz_v[to] > sz_v[heavy]) heavy = to;
        }
        // Push light children in reverse adjacency order and the heavy child last. The
        // heavy subtree is then numbered first, and the light subtrees follow in the
        // original adjacency order.
        for(auto it = g[v].rbegin(); it != g[v].rend(); ++it){
            const int to = it->first;
            if(to == p || to == heavy) continue;
            head[to] = to;
            stk.emplace_back(to, v);
        }
        if(heavy != -1){
            head[heavy] = head[v];
            stk.emplace_back(heavy, v);
        }
    }
}

// Returns the top-most node reachable from x by repeatedly climbing tree edges whose
// mark in `t` is 1; x itself is returned when the edge above x is unmarked.
// Convention (HLD layout from hld()): position id[v] holds the mark of the edge between
// v and parent_v[v]; the root's position carries no edge.
// NOTE(review): assumes every position holds a 0/1 mark, since the sums below count
// marked edges; other values stored in `t` would confuse the checks.
int up(int x){
    while(true){
        const int h = head[x];
        const int hid = id[h];
        const int xid = id[x];
        const bool hIsRoot = (parent_v[h] == -1);
        // Edges from x up to (and including) the light edge into h: positions lo..xid.
        // The light edge at position hid exists only when h is not the root.
        const int lo = hIsRoot ? hid + 1 : hid;

        // x is the root, or the edge directly above x is unmarked: cannot climb at all.
        if(lo > xid || t.get(xid, xid) != 1) return x;

        if(t.get(lo, xid) == (long long)(xid - lo + 1)){
            // Every edge on this chain segment is marked (including the light edge into h
            // when it exists): continue above the chain, or stop at the root.
            if(hIsRoot) return h;
            x = parent_v[h];
            continue;
        }

        // Some edge in lo..xid is unmarked while xid is marked. Binary search for the
        // smallest p such that positions p..xid are all marked (a monotone predicate).
        int L = lo, R = xid;
        while(L < R){
            const int M = (L + R) >> 1;
            if(t.get(M, xid) == (long long)(xid - M + 1)) R = M;
            else L = M + 1;
        }
        // Edges L..xid are marked and the one at L-1 is not; the top node is the upper
        // endpoint of the edge at position L.
        return parent_v[id_rev[L]];
    }
}

// Assigns val (a 0/1 mark) to every edge on the path from x to the root, chain by chain.
// Position id[v] in `t` represents the edge between v and parent_v[v]. On each chain
// this covers positions id[head]..id[x], where position id[head] is the light edge
// leading into the chain. That edge is skipped only for the root's chain, whose head
// has no parent edge.
void set_path_up(int x, int val){
    while(true){
        const int h = head[x];
        const int hid = id[h];
        const int xid = id[x];
        const bool hIsRoot = (parent_v[h] == -1);
        const int lo = hIsRoot ? hid + 1 : hid;
        if(lo <= xid) t.upd(lo, xid, val);
        if(hIsRoot) break;
        x = parent_v[h];
    }
}

void Init(int N, int A[], int B[], int D[]){
    Nglob = N;
    // clear graph arrays for N nodes
    for(int i = 0; i < Nglob; ++i){
        g[i].clear();
        id[i] = head[i] = id_rev[i] = parent_v[i] = sz_v[i] = 0;
        dep[i] = 0;
    }
    timer_glob = 0;

    // read exactly N-1 edges from given arrays
    for(int i = 0; i < Nglob - 1; ++i){
        int a = A[i], b = B[i], d = D[i];
        g[a].pb({b, d});
        g[b].pb({a, d});
    }

    // root at 0
    parent_v[0] = -1;
    dep[0] = 0;
    dfs(0, -1);
    head[0] = 0;
    hld(0, -1);

    // init segtree with size N (positions 1..N)
    t = Bit(Nglob);
}

long long Query(int S, int X[], int T, int Y[]){
    // mark paths from each X[i] up to chain heads (set edges to 1)
    for(int i = 0; i < S; ++i){
        set_path_up(X[i], 1);
    }

    // for each Y, compute its up() and distance
    vector<pair<int,long long>> vs;
    vs.reserve(T);
    for(int i = 0; i < T; ++i){
        int u = up(Y[i]);
        long long dist = dep[Y[i]] - dep[u];
        vs.emplace_back(u, dist);
    }

    // unmark X paths
    for(int i = 0; i < S; ++i){
        set_path_up(X[i], 0);
    }

    // reduce vs to minimal per u
    sort(vs.begin(), vs.end());
    vector<pair<int,long long>> compact;
    for(size_t i = 0; i < vs.size(); ++i){
        if(i + 1 == vs.size() || vs[i+1].first != vs[i].first){
            compact.push_back(vs[i]);
        } else {
            if(vs[i+1].second > vs[i].second) vs[i+1].second = vs[i].second;
        }
    }

    // put minimal values in seg tree at id[u] (temporary)
    for(auto &pr : compact){
        int u = pr.first;
        long long val = pr.second;
        t.upd(id[u], id[u], (int)val); // store as integer value (fits problem constraints)
    }

    long long res = LLONG_MAX;
    for(int i = 0; i < S; ++i){
        int u = up(X[i]);
        long long cell = t.get(id[u], id[u]);
        if(cell == 0) continue;
        long long cand = (dep[X[i]] - dep[u]) + cell;
        if(cand < res) res = cand;
    }

    // clear temporary marks
    for(auto &pr : compact){
        int u = pr.first;
        t.upd(id[u], id[u], 0);
    }

    if(res == LLONG_MAX) return -1;
    return res;
}