Submission #1210656

# | Time | Username | Problem | Language | Result | Execution time | Memory
1210656 | — | trimkus | Tree (IOI24_tree) | C++20 | 31 / 100 | 1438 ms | 31940 KiB
#include "tree.h"
#include <bits/stdc++.h>
using namespace std;
using ll = long long;

// Capacity of every fixed-size, vertex-indexed array in this file.
// The full IOI 2024 "Tree" limits allow up to 200000 vertices; a smaller bound
// makes those arrays overrun on the large inputs.
static const int MAXN = 200000;

// Number of vertices; set by init().
int N;
// W[v]: weight of vertex v (copy of init()'s second argument).
vector<int> W;
// adj[v]: neighbours of v in the tree (children and, for v != 0, the parent); built by init().
vector<vector<int>> adj;
// P[v]: parent of v; init() stores -1 for the root (vertex 0).
vector<int> P;

// Heavy-light decomposition data.
//   sz[v]    - size of v's subtree (dfs_size)
//   heavy[v] - child of v with the largest subtree, or -1 if v has no child (dfs_size)
//   depth[v] - number of edges between v and the root (dfs_size)
int sz[MAXN], heavy[MAXN], depth[MAXN];
//   head[v]  - topmost vertex of the chain containing v (decompose)
//   pos[v]   - position of v in the segment-tree order (decompose)
//   cur_pos  - next unassigned position while decompose() runs
int head[MAXN], pos[MAXN], cur_pos;
// base[p]: initial value for position p; decompose() writes 0 everywhere.
vector<int> base;

// Euler-tour entry / exit stamps assigned by init(), counted by timerDFS.
int tin[MAXN], tout[MAXN];
int timerDFS;

// Segment tree over positions [0, n) supporting "add a value to a range" and
// "minimum over a range", with additions propagated lazily.
// Node 1 is the root; node k has children 2k and 2k+1.
struct SegTree {
    int n;              // number of positions covered
    vector<ll> st;      // st[k]: minimum over node k's range, including adds applied at k
    vector<ll> lazy;    // lazy[k]: add applied at k that its children have not received yet

    SegTree(int _n = 0) {
        init(_n);
    }

    // (Re)build the tree for _n positions; every value and pending add becomes 0.
    void init(int _n) {
        n = _n;
        st.assign(4 * n, ll{0});
        lazy.assign(4 * n, ll{0});
    }

    // Add delta to the whole range of node k without descending into it.
    void bump(int k, ll delta) {
        st[k] += delta;
        lazy[k] += delta;
    }

    // Hand node idx's pending add down to its two children.
    // A leaf (L == R) has no children, so its pending add is simply dropped.
    void push(int idx, int L, int R) {
        if (L == R) {
            lazy[idx] = 0;
            return;
        }
        if (L > R) return;

        const ll pending = lazy[idx];
        if (pending == 0) return;

        bump(2 * idx, pending);
        bump(2 * idx + 1, pending);
        lazy[idx] = 0;
    }

    // Add val to every position in [i, j] that lies inside node idx's range [L, R].
    void update_range(int idx, int L, int R, int i, int j, ll val) {
        const bool disjoint = (R < i) || (j < L);
        if (disjoint) return;

        const bool fullyInside = (i <= L) && (R <= j);
        if (fullyInside) {
            bump(idx, val);
            return;
        }

        push(idx, L, R);
        const int mid = L + (R - L) / 2;
        const int left = 2 * idx;
        const int right = left + 1;
        update_range(left, L, mid, i, j, val);
        update_range(right, mid + 1, R, i, j, val);
        st[idx] = min(st[left], st[right]);
    }

    // Minimum over the part of [i, j] inside node idx's range [L, R];
    // LLONG_MAX when the two ranges do not intersect.
    ll query_min(int idx, int L, int R, int i, int j) {
        const bool disjoint = (R < i) || (j < L);
        if (disjoint) return LLONG_MAX;

        const bool fullyInside = (i <= L) && (R <= j);
        if (fullyInside) return st[idx];

        push(idx, L, R);
        const int mid = L + (R - L) / 2;
        const ll leftMin = query_min(2 * idx, L, mid, i, j);
        const ll rightMin = query_min(2 * idx + 1, mid + 1, R, i, j);
        return min(leftMin, rightMin);
    }

    // Public form: add val to positions l..r (no-op for an empty range).
    void update_range(int l, int r, ll val) {
        if (l <= r) update_range(1, 0, n - 1, l, r, val);
    }

    // Public form: minimum over positions l..r (LLONG_MAX for an empty range).
    ll query_min(int l, int r) {
        return (l <= r) ? query_min(1, 0, n - 1, l, r) : LLONG_MAX;
    }
};

SegTree seg;


// First DFS of the heavy-light decomposition.
// Fills sz[v] (subtree size), heavy[v] (child with the largest subtree, or -1
// when v has no child) and depth[] of every vertex below v.
// Returns the size of v's subtree.
int dfs_size(int v) {
    int total = 1;          // counts v itself
    int heaviest = -1;
    int heaviestSize = 0;

    for (int child : adj[v]) {
        if (child == P[v]) continue;  // adj also lists the parent

        depth[child] = depth[v] + 1;
        const int sub = dfs_size(child);
        total += sub;
        // Strict comparison: on ties the earliest child in adj[v] stays heavy.
        if (sub > heaviestSize) {
            heaviestSize = sub;
            heaviest = child;
        }
    }

    heavy[v] = heaviest;
    sz[v] = total;
    return total;
}

// Second DFS of the heavy-light decomposition: gives v the next free position,
// records h as the top of v's chain, and zeroes the matching base[] slot.
// The heavy child is visited first so each chain occupies consecutive positions.
void decompose(int v, int h) {
    const int slot = cur_pos++;
    head[v] = h;
    pos[v] = slot;
    base[slot] = 0;  // initial sum stored for this vertex

    const int heavyChild = heavy[v];
    if (heavyChild != -1) {
        decompose(heavyChild, h);  // stays on the current chain
    }

    for (int u : adj[v]) {
        const bool lightChild = (u != P[v]) && (u != heavyChild);
        if (lightChild) {
            decompose(u, u);  // a light child is the head of a new chain
        }
    }
}

// Stores the tree and precomputes everything the queries need:
// adjacency lists, Euler-tour stamps (tin/tout), subtree sizes / heavy children
// / depths, and the heavy-light positions.
//   _P - parent of each vertex; entry 0 is overwritten with -1 for the root
//   _W - weight of each vertex
void init(vector<int> _P, vector<int> _W) {
    P = std::move(_P);
    P[0] = -1;
    W = std::move(_W);
    N = static_cast<int>(P.size());

    // Undirected adjacency: each edge is stored at both endpoints.
    adj.assign(N, vector<int>());
    for (int child = 1; child < N; ++child) {
        const int parent = P[child];
        adj[parent].push_back(child);
        adj[child].push_back(parent);
    }

    // Euler tour: the counter advances once on entering and once on leaving a vertex.
    timerDFS = 1;
    auto euler = [&](auto&& self, int v, int from) -> void {
        tin[v] = timerDFS++;
        for (int u : adj[v]) {
            if (u != from) self(self, u, v);
        }
        tout[v] = timerDFS++;
    };
    euler(euler, 0, -1);

    // Heavy-light decomposition rooted at vertex 0.
    depth[0] = 0;
    dfs_size(0);

    cur_pos = 0;
    base.resize(N);
    decompose(0, 0);
}


// Adds delta to the stored value of v and of every ancestor of v up to the root (0).
// Walks chain by chain: each iteration covers the segment from the chain head
// down to v, then jumps to the parent of that head.
void path_update(int v, ll delta) {
    const int rootChain = head[0];
    for (int top = head[v]; top != rootChain; top = head[v]) {
        seg.update_range(pos[top], pos[v], delta);
        v = P[top];
    }
    // v now lies on the root's chain: finish with the stretch root..v.
    seg.update_range(pos[0], pos[v], delta);
}

// Returns the smallest stored value among v and all of its ancestors up to the root (0).
ll path_query_min(int v) {
    const int rootChain = head[0];
    ll best = LLONG_MAX;
    for (int top = head[v]; top != rootChain; top = head[v]) {
        const ll chainMin = seg.query_min(pos[top], pos[v]);
        best = min(best, chainMin);
        v = P[top];  // continue from the vertex just above this chain
    }
    const ll rootPart = seg.query_min(pos[0], pos[v]);
    return min(best, rootPart);
}

// Returns the value currently stored for vertex v (a one-position range query).
ll get_sum(int v) {
    const int at = pos[v];
    return seg.query_min(at, at);
}
// Scratch state that query() resets and reuses on every call.
// leaves[k]: max-heap of (-W[x], x) pairs, so top() is an entry with the smallest weight.
priority_queue<pair<int, int>> leaves[MAXN];
// ptr[x]: index into leaves[] of the heap currently associated with vertex x.
vector<int> ptr(MAXN);
// bool inq[MAXN];
// degree[x]: number of children of x that query() has not finished processing yet.
int degree[MAXN];

// Answers one (L, R) query of the IOI24 "Tree" task and returns the total cost.
//
// The segment tree holds, for every vertex, the current sum over its subtree.
// 1. Every childless vertex receives +L, costing W[leaf] * L.
// 2. Vertices are processed children-before-parents. While a vertex's subtree
//    sum exceeds R, the cheapest still-usable vertex A in that subtree is
//    lowered; this costs W[A] per unit and lowers the sum of A and of all its
//    ancestors. A may only be lowered while every sum on its root path stays
//    >= L.
// 3. Candidate heaps of the children are merged into the parent small-to-large.
//
// Changes from the previous version: a lone root (N == 1) is now seeded as a
// leaf instead of tripping the final assertion, and the per-query debug line
// on stderr is gone.
long long query(int L, int R) {
    seg.init(N);

    ll res = 0;

    // Vertex i starts out owning heap i; only the first N slots are ever read.
    iota(ptr.begin(), ptr.begin() + N, 0);

    // Reset the scratch state: empty heaps, degree[i] = number of children of i.
    // adj[i] also lists the parent for every vertex except the root.
    for (int i = 0; i < N; i++) {
        leaves[i] = priority_queue<pair<int, int>>();
        degree[i] = static_cast<int>(adj[i].size()) - (i == 0 ? 0 : 1);
    }

    // Seed every childless vertex with L. Starting at 0 covers the single-vertex
    // tree, where the root itself is the only leaf.
    queue<int> q;
    for (int i = 0; i < N; i++) {
        if (degree[i] != 0) continue;
        q.push(i);
        path_update(i, L);
        res += 1LL * W[i] * L;
    }

    while (!q.empty()) {
        const int v = q.front();
        q.pop();

        int ptrv = ptr[v];
        leaves[ptrv].push({ -W[v], v });  // v itself becomes a candidate

        ll sum = get_sum(v);
        while (!leaves[ptrv].empty() && sum > R) {
            const int A = leaves[ptrv].top().second;  // cheapest candidate

            const ll mnAlongPath = path_query_min(A);
            if (mnAlongPath == L) {
                // Some sum on A's root path already sits at L: A cannot be lowered.
                leaves[ptrv].pop();
                continue;
            }

            const ll canSubtract = sum - R;         // excess still to remove at v
            const ll maxSubAtA = mnAlongPath - L;   // slack before the path hits L
            const ll delta = min(canSubtract, maxSubAtA);

            path_update(A, -delta);
            res += 1LL * W[A] * delta;
            sum -= delta;

            if (mnAlongPath - delta == L) {
                leaves[ptrv].pop();  // A is exhausted
            }
        }

        const int u = P[v];
        if (u == -1) continue;  // v is the root: nothing to merge into

        // Small-to-large: make ptr[u] name the bigger heap, then pour the smaller into it.
        if (leaves[ptr[v]].size() > leaves[ptr[u]].size()) swap(ptr[v], ptr[u]);
        ptrv = ptr[v];
        const int ptru = ptr[u];
        while (!leaves[ptrv].empty()) {
            leaves[ptru].push(leaves[ptrv].top());
            leaves[ptrv].pop();
        }

        // The parent becomes ready once all of its children are done.
        if (--degree[u] == 0) {
            q.push(u);
        }
    }

    // Sanity check: every subtree sum must have ended up inside [L, R].
    for (int i = 0; i < N; i++) {
        const ll sv = get_sum(i);
        assert(sv >= L && sv <= R);
    }

    return res;
}

# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...