#include <bits/stdc++.h>
#include "tree.h"
using namespace std;
typedef long long ll;
// ---- Tree state filled in by init() and read by query()/dfs() ----
// Number of vertices (size of the parent array passed to init()).
ll n;
// Parent array and vertex weights, copied from init()'s arguments.
vector<int> p, w;
// adj[v] lists the children of v, in increasing vertex index order.
vector<vector<ll>> adj;
// Binary-lifting table: jmp[v][j] = {id of v's 2^j-th ancestor (stopping at root 0),
// minimum weight over the ancestors passed on the way, from v's parent up to that ancestor}.
// The root's entries carry 1<<62 as the minimum, i.e. "no ancestor".
vector<vector<pair<ll, ll>>> jmp; // id, min
// lla[i] = nearest proper ancestor of i whose weight is <= w[i]; 0 if there is none (and for i = 0).
vector<ll> lla;
// Preorder index of each vertex and the largest preorder index inside its subtree,
// so the subtree of v is the contiguous range [rngId[v], lastId[v]].
vector<ll> rngId, lastId;
// Next preorder index to hand out; reset to 0 by init() before dsf(0).
ll idk;
// Point-assign / range-sum segment tree over positions [0, pow2).
// Layout is a 1-indexed implicit heap: node k has children 2k and 2k+1,
// and position `pos` is stored at leaf seg[pow2 + pos].
struct seg_tree {
    long long pow2;              // number of leaves: smallest power of two >= max(n, 1)
    std::vector<long long> seg;  // 2 * pow2 entries; seg[0] is unused

    // Builds an all-zero tree able to hold n positions.
    // pow2 is computed with integer arithmetic: the old ceil(log2(n)) form was
    // undefined for n <= 0 (log2(0) is -inf) and needlessly went through floating point.
    seg_tree(long long n) : pow2(1) {
        while (pow2 < n) pow2 <<= 1;
        seg.assign(2 * pow2, 0);
    }

    // Sum of positions in [l, r] intersected with node k's interval [tl, tr].
    // Callers query the whole tree as sum(l, r, 1, 0, pow2 - 1).
    long long sum(long long l, long long r, long long k, long long tl, long long tr) const {
        if (tl > r || tr < l) return 0;           // disjoint
        if (l <= tl && tr <= r) return seg[k];    // fully covered
        const long long mid = (tl + tr) / 2;
        return sum(l, r, 2 * k, tl, mid) + sum(l, r, 2 * k + 1, mid + 1, tr);
    }

    // Sets position `pos` to x and refreshes every ancestor's sum.
    void point(long long pos, long long x) {
        pos += pow2;
        seg[pos] = x;
        for (pos /= 2; pos >= 1; pos /= 2) {
            seg[pos] = seg[2 * pos] + seg[2 * pos + 1];
        }
    }
};
// Assigns preorder indices to the subtree rooted at `cur`, continuing from the global
// counter `idk`. rngId[v] receives v's own index and lastId[v] the largest index inside
// v's subtree, so that subtree occupies [rngId[v], lastId[v]]. Children are visited in
// adj order.
// Implemented with an explicit stack: a recursive walk needs call depth equal to the
// tree height, which can overflow the stack on deep (e.g. path-shaped) trees.
void dsf(ll cur) {
    vector<ll> order;     // vertices in the order they receive their preorder index
    vector<ll> stk{cur};
    while (!stk.empty()) {
        const ll v = stk.back();
        stk.pop_back();
        rngId[v] = idk++;
        order.push_back(v);
        // Push children reversed so adj[v][0] is popped (numbered) first.
        for (auto it = adj[v].rbegin(); it != adj[v].rend(); ++it) stk.push_back(*it);
    }
    // Walking reverse preorder visits every vertex after all of its descendants. A vertex's
    // subtree ends where its last child's subtree ends; a leaf's ends at itself.
    for (auto it = order.rbegin(); it != order.rend(); ++it) {
        const ll v = *it;
        lastId[v] = adj[v].empty() ? rngId[v] : lastId[adj[v].back()];
    }
}
// Stores the tree (P = parent array with vertex 0 as root, W = weights) and precomputes
// query-independent data:
//   adj          - children lists,
//   jmp          - binary-lifting table of (ancestor, min weight on the way),
//   lla[i]       - nearest proper ancestor of i with weight <= w[i] (0 if none),
//   rngId/lastId - preorder ranges of every subtree (via dsf).
void init(vector<int> P, vector<int> W) {
    p = P;
    w = W;
    n = (ll)p.size();
    rngId = lastId = vector<ll>(n);
    adj = vector<vector<ll>>(n);
    lla = vector<ll>(n);
    jmp.clear();
    if (n == 0) return;  // no root: nothing below may index vertex 0

    for (ll i = 1; i < n; i++) {
        adj[p[i]].push_back(i);
    }

    // Use enough levels that a greedy descent can climb 2^LOG - 1 >= n - 1 steps
    // (the maximum possible depth); a fixed 20 levels silently fails above 2^20 vertices.
    ll LOG = 1;
    while ((1ll << LOG) < n) LOG++;
    const ll INF = 1ll << 62;

    jmp = vector<vector<pair<ll, ll>>>(n, vector<pair<ll, ll>>(LOG));
    jmp[0][0] = {0, INF};  // root jumps to itself; INF means "no ancestor weight"
    for (ll i = 1; i < n; i++) jmp[i][0] = {p[i], w[p[i]]};
    for (ll j = 1; j < LOG; j++) {
        for (ll i = 0; i < n; i++) {
            const auto &half = jmp[i][j - 1];             // first 2^(j-1) steps
            const auto &rest = jmp[half.first][j - 1];    // next 2^(j-1) steps
            jmp[i][j] = {rest.first, min(half.second, rest.second)};
        }
    }

    // Climb from i as long as every ancestor passed is strictly heavier than w[i];
    // the parent of the final vertex is the nearest ancestor with weight <= w[i].
    for (ll i = 0; i < n; i++) {
        ll cur = i;
        for (ll j = LOG - 1; j >= 0; j--) {
            if (jmp[cur][j].second > w[i]) cur = jmp[cur][j].first;
        }
        lla[i] = jmp[cur][0].first;
    }

    idk = 0;
    dsf(0);
}
// Bounds of the current query: set by query() and read by dfs().
ll L, R;
// Value returned by query(); reset to 0 per query, dfs() adds w[v] * amount to it.
ll result;
// Post-order greedy for the current (L, R) query over the subtree of `cur`.
//  - A leaf gets coefficient L (stored at its preorder slot in `tree`), costing w[leaf] * L.
//  - For an internal vertex, the children's candidate heaps are merged small-to-large into
//    vq[cur] and cur itself is added. Entries are (-w[i], i), so the max-heap yields the
//    lightest vertex first (ties: larger index first).
//  - While cur's subtree sum exceeds R, the top candidate i is examined:
//      * it is dropped for good if lla[i] is already processed (lla[i] is then a proper
//        descendant of cur that is no heavier than i), or if i's subtree sum is already L;
//      * otherwise i's coefficient is lowered by min(subtree(i) - L, subtree(cur) - R),
//        adding w[i] per unit to `result`.
// Each range sum is now queried once per loop iteration. The original re-ran the same
// O(log n) queries for cur and i inside the loop body.
// NOTE(review): recursion depth equals the tree height; very deep trees may need a larger stack.
void dfs(ll cur, vector<priority_queue<pair<ll, ll>>> &vq, vector<ll> &processed, seg_tree &tree) {
    // Current sum of coefficients over v's subtree (its preorder range).
    auto subtree_sum = [&](ll v) {
        return tree.sum(rngId[v], lastId[v], 1, 0, tree.pow2 - 1);
    };

    if (adj[cur].empty()) {
        tree.point(rngId[cur], L);
        result += w[cur] * L;
        processed[cur] = 1;
        return;
    }

    for (auto node : adj[cur]) {
        dfs(node, vq, processed, tree);
    }

    // Small-to-large: always pour the smaller heap into the larger one.
    for (auto node : adj[cur]) {
        if (vq[node].size() > vq[cur].size()) {
            swap(vq[node], vq[cur]);
        }
        while (!vq[node].empty()) {
            vq[cur].push(vq[node].top());
            vq[node].pop();
        }
    }
    vq[cur].push({-w[cur], cur});

    ll excess;  // how far cur's subtree sum is above R
    while ((excess = subtree_sum(cur) - R) > 0) {
        auto [neg_w, i] = vq[cur].top();
        const ll slack = subtree_sum(i) - L;  // how far i's subtree sum is above L
        if (processed[lla[i]] || slack == 0) {
            vq[cur].pop();
            continue;
        }
        const ll remove_count = min(slack, excess);
        tree.point(rngId[i], tree.seg[tree.pow2 + rngId[i]] - remove_count);
        result += -neg_w * remove_count;
    }
    processed[cur] = 1;
}
// Answers one query with bounds [lo, hi]. It publishes the bounds to the globals read by
// dfs(), runs the greedy from root 0 on fresh per-query state, and returns the total
// dfs() accumulated in `result`.
// Parameters were previously named _L/_R. Identifiers beginning with an underscore followed
// by an uppercase letter are reserved for the implementation in C++. Parameter names are not
// part of the call signature, so callers are unaffected.
ll query(int lo, int hi) {
    L = lo;
    R = hi;
    result = 0;
    if (n == 0) return result;  // empty tree: dfs(0) would index out of bounds

    vector<priority_queue<pair<ll, ll>>> vq(n); // per-vertex heaps of (-w, i)
    vector<ll> processed(n);                    // 1 once dfs has finished a vertex
    seg_tree tree(n);                           // coefficients indexed by preorder id
    dfs(0, vq, processed, tree);
    return result;
}
#ifdef TEST
#include "grader.cpp"
#endif
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |