Submission #1151209

# | Time | Username | Problem | Language | Result | Execution time | Memory
1151209 | (not shown) | Zicrus | Tree (IOI24_tree) | C++20 | 41 / 100 | 2097 ms | 136296 KiB
#include <bits/stdc++.h>
#include "tree.h"
using namespace std;

// 64-bit integer used for indices, coefficient sums and weight * coefficient costs.
using ll = long long;

// ---- Tree state, filled in by init() ----
ll n;                              // number of vertices
vector<int> p, w;                  // p[i] = parent of vertex i (for i >= 1); w[i] = weight of vertex i
vector<vector<ll>> adj;            // adj[v] = children of v, in increasing index order
vector<vector<pair<ll, ll>>> jmp;  // jmp[v][j] = {2^j-th ancestor of v (saturating at root 0),
                                   //              min weight among the ancestors passed on that jump}
vector<ll> lla;                    // lla[v] = nearest proper ancestor with weight <= w[v]; 0 if none
vector<ll> rngId, lastId;          // preorder id of v, and the last preorder id inside v's subtree
ll idk;                            // running preorder counter advanced by dsf()

// Sum segment tree over positions [0, pow2), stored as an implicit binary heap:
// seg[1] is the root, node k has children 2k and 2k+1, and position `pos` is the
// leaf seg[pow2 + pos]. Every position starts at zero.
struct seg_tree {
    long long pow2;              // number of leaves: smallest power of two >= max(n, 1)
    std::vector<long long> seg;  // 2 * pow2 slots; slot 0 is unused

    // Builds an all-zero tree with room for at least n positions.
    // Integer doubling replaces ceil(log2(n)), which is undefined for n <= 0
    // (log2(0) is -inf) and depends on floating-point rounding.
    seg_tree(long long n) {
        pow2 = 1;
        while (pow2 < n) pow2 <<= 1;
        seg = std::vector<long long>(2 * pow2);
    }

    // Returns the sum of positions in [l, r] (inclusive) that also lie in the
    // range [tl, tr] covered by node k. Query the whole tree with
    // sum(l, r, 1, 0, pow2 - 1).
    long long sum(long long l, long long r, long long k, long long tl, long long tr) {
        if (tl > r || tr < l) return 0;         // node range disjoint from [l, r]
        if (l <= tl && tr <= r) return seg[k];  // node range fully inside [l, r]
        const long long mid = (tl + tr) / 2;
        return sum(l, r, 2 * k, tl, mid) + sum(l, r, 2 * k + 1, mid + 1, tr);
    }

    // Overwrites position pos with x (it does not add), then recomputes every
    // ancestor sum on the path up to the root.
    void point(long long pos, long long x) {
        long long k = pos + pow2;
        seg[k] = x;
        for (k /= 2; k >= 1; k /= 2) seg[k] = seg[2 * k] + seg[2 * k + 1];
    }
};

// Assigns preorder ids to the subtree rooted at `cur`, continuing from the global
// counter idk. Afterwards rngId[v] is v's own id and lastId[v] is the largest id
// inside v's subtree, so the subtree of v is exactly the id range
// [rngId[v], lastId[v]]. Children are visited in adj[] order.
// Uses an explicit stack of (vertex, next child index) instead of recursion.
void dsf(ll cur) {
    vector<pair<ll, size_t>> pending;
    rngId[cur] = idk++;
    pending.emplace_back(cur, 0);
    while (!pending.empty()) {
        const ll v = pending.back().first;
        size_t &next_child = pending.back().second;
        if (next_child == adj[v].size()) {
            // Every descendant of v has been numbered.
            lastId[v] = idk - 1;
            pending.pop_back();
            continue;
        }
        // Read and advance the child index before emplace_back may reallocate `pending`.
        const ll child = adj[v][next_child++];
        rngId[child] = idk++;
        pending.emplace_back(child, 0);
    }
}

// Stores the tree and precomputes what query() needs:
//   adj            - children lists built from the parent array (vertex 0 is the root),
//   jmp            - binary-lifting table; jmp[v][j] holds the 2^j-th ancestor of v and
//                    the minimum weight among the ancestors passed on that jump,
//   lla            - nearest proper ancestor whose weight is <= w[v] (0 when none),
//   rngId / lastId - preorder interval of every subtree (filled by dsf).
// NOTE(review): assumes W has at least P.size() entries.
void init(vector<int> P, vector<int> W) {
    p = std::move(P);
    w = std::move(W);
    n = static_cast<int>(p.size());

    adj = vector<vector<ll>>(n);
    for (ll child = 1; child < n; ++child) adj[p[child]].push_back(child);

    rngId = vector<ll>(n);
    lastId = vector<ll>(n);

    // Number of doubling levels: jumps of up to 2^LEVELS - 1 steps are representable.
    const ll LEVELS = 20;
    jmp = vector<vector<pair<ll, ll>>>(n, vector<pair<ll, ll>>(LEVELS));
    // The root jumps to itself; the huge sentinel never lowers a running minimum.
    jmp[0][0] = {0, 1ll << 62};
    for (ll v = 1; v < n; ++v) jmp[v][0] = {p[v], w[p[v]]};
    for (ll lvl = 1; lvl < LEVELS; ++lvl) {
        for (ll v = 0; v < n; ++v) {
            const pair<ll, ll> first_half = jmp[v][lvl - 1];
            const pair<ll, ll> second_half = jmp[first_half.first][lvl - 1];
            jmp[v][lvl] = {second_half.first, min(first_half.second, second_half.second)};
        }
    }

    lla = vector<ll>(n);
    for (ll v = 0; v < n; ++v) {
        // Climb while every ancestor covered by the next jump is strictly heavier than v;
        // the parent of where we stop is the nearest ancestor with weight <= w[v].
        ll top = v;
        for (ll lvl = LEVELS - 1; lvl >= 0; --lvl) {
            if (jmp[top][lvl].second > w[v]) top = jmp[top][lvl].first;
        }
        lla[v] = jmp[top][0].first;
    }

    idk = 0;
    dsf(0);
}

// Query-wide state shared with dfs(): the bounds [L, R] handed to query() and the
// answer accumulated while the tree is processed. Reset at the start of every query.
ll L = 0, R = 0;
ll result = 0;

// Processes the subtree of `cur` bottom-up for the current query bounds [L, R].
//  - A leaf gets coefficient L and adds w[cur] * L to `result`.
//  - An internal vertex merges its children's candidate heaps into vq[cur]
//    (small-to-large) and adds itself as a candidate. Then, while the subtree sum
//    of `cur` exceeds R, it lowers the coefficient of the top (lowest-weight)
//    candidate and adds weight * amount to `result`.
// vq[v]     : max-heap of {-weight, vertex}, so the lowest weight is on top.
// processed : set to 1 once the dfs call for a vertex has finished.
// tree      : coefficients indexed by preorder id (rngId); the range
//             [rngId[v], lastId[v]] sums to the subtree sum of v.
void dfs(ll cur, vector<priority_queue<pair<ll, ll>>> &vq, vector<ll> &processed, seg_tree &tree) {
    // Current sum of coefficients inside the subtree of v.
    auto subtree_sum = [&](ll v) { return tree.sum(rngId[v], lastId[v], 1, 0, tree.pow2 - 1); };

    if (adj[cur].empty()) {
        tree.point(rngId[cur], L);
        result += w[cur] * L;
        processed[cur] = 1;
        return;
    }
    for (ll child : adj[cur]) dfs(child, vq, processed, tree);

    // Small-to-large merge of the children's heaps into vq[cur].
    for (ll child : adj[cur]) {
        if (vq[child].size() > vq[cur].size()) swap(vq[child], vq[cur]);
        while (!vq[child].empty()) {
            vq[cur].push(vq[child].top());
            vq[child].pop();
        }
        // vq[child] is never read again this query. Popping keeps the buffer's
        // capacity, so swap with an empty heap to actually free the memory.
        priority_queue<pair<ll, ll>>().swap(vq[child]);
    }
    vq[cur].push({-w[cur], cur});

    for (;;) {
        // Cached once per iteration; nothing changes the tree until the update below.
        const ll cur_sum = subtree_sum(cur);
        if (cur_sum <= R) break;

        const auto [neg_weight, i] = vq[cur].top();
        // Discard candidates whose lla[] ancestor has already finished processing...
        if (processed[lla[i]]) {
            vq[cur].pop();
            continue;
        }
        // ...and those whose subtree sum already sits at the lower bound L.
        const ll i_sum = subtree_sum(i);
        if (i_sum == L) {
            vq[cur].pop();
            continue;
        }

        // Remove until the subtree of i reaches L or the subtree of cur reaches R.
        const ll remove_count = min(i_sum - L, cur_sum - R);
        tree.point(rngId[i], tree.seg[tree.pow2 + rngId[i]] - remove_count);
        result += -neg_weight * remove_count;
    }
    processed[cur] = 1;
}

// Runs one query with bounds [lo, hi] and returns the cost accumulated by dfs().
// All per-query scratch state (candidate heaps, processed flags, segment tree) is
// allocated afresh, so every call does at least O(n) work.
// The parameters were previously named _L / _R. Identifiers that start with an
// underscore followed by an uppercase letter are reserved for the implementation
// (e.g. newlib's <ctype.h> defines _L as a macro), so they are renamed here.
ll query(int lo, int hi) {
    L = lo;
    R = hi;
    result = 0;
    vector<priority_queue<pair<ll, ll>>> vq(n);  // per-vertex heaps of {-weight, vertex}
    vector<ll> processed(n);                     // 1 once dfs() has finished a vertex
    seg_tree tree(n);                            // coefficients indexed by preorder id

    dfs(0, vq, processed, tree);
    return result;
}

#ifdef TEST
#include "grader.cpp"
#endif
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...