#include <bits/stdc++.h>
#include "tree.h"
using namespace std;
using ll = long long;
// Capacity of every global array below. `st` is a recursive segment tree whose node
// indices grow to roughly 4x its leaf count (timer + 2 leaves), so this presumably
// leaves room for about 2e5 vertices — TODO confirm against the input limits.
const int mxN = 8e5+10;
// Sentinel weight for a segment-tree entry that has nothing left to offer
// (unused positions, out-of-range queries, exhausted entries).
const int INF = 1e9+10;
// adj[v]: children of vertex v, filled by init() from the parent array.
vector<int> adj[mxN];
// w[v]: weight of vertex v. ans: total cost accumulated by the current query.
ll w[mxN], ans;
// cnt[v]: subtree sum assigned to v while the current query's dfs runs.
ll cnt[mxN];
// Euler tour: timer is the last position handed out (starts at -1 so the root gets 0);
// v's subtree occupies positions [tin[v], tout[v]]. n is the number of vertices.
int timer = -1, n, tin[mxN], tout[mxN];
// Segment tree over tour positions. Each entry is {weight, amount still available,
// position}; internal nodes hold the lexicographic minimum of their two children.
array<ll, 3> st[mxN];
/// Resets segment-tree node `node`, which covers positions [l, r], before a query.
/// Every leaf becomes the empty entry {INF, 0, position}; an internal node stores the
/// lexicographic minimum of its two children.
void build(int node, int l, int r) {
    if (l != r) {
        const int half = (l + r) / 2;
        const int lhs = 2 * node;
        const int rhs = 2 * node + 1;
        build(lhs, l, half);
        build(rhs, half + 1, r);
        st[node] = min(st[lhs], st[rhs]);
        return;
    }
    st[node] = {INF, 0, l};
}
/// Point assignment: stores `now` at position k in the subtree rooted at `node`
/// (covering [l, r]) and refreshes the minima on the way back up.
/// A position outside [l, r] leaves the tree untouched.
void upd(int node, int l, int r, int k, array<ll, 3> now) {
    if (k < l || k > r) return;
    if (l == r) {
        st[node] = now;
        return;
    }
    const int mid = (l + r) / 2;
    if (k <= mid) {
        upd(node * 2, l, mid, k, now);
    } else {
        upd(node * 2 + 1, mid + 1, r, k, now);
    }
    st[node] = min(st[node * 2], st[node * 2 + 1]);
}
array<ll, 3> qry(int node, int l, int r, int l1, int r1) {
if(l1 <= l && r <= r1) return st[node];
if(l1 > r || r1 < l) return {INF, 0, 0};
int mid = (l+r)/2;
return min(qry(node*2, l, mid, l1, r1), qry(node*2+1, mid+1, r, l1, r1));
}
void dfs(int node, int p, ll l, ll r) {
bool leaf = true;
for(auto it : adj[node]) {
if(it == p) continue;
dfs(it, node, l, r);
cnt[node] += cnt[it];
leaf = false;
}
if(leaf) {
ans += l * w[node];
cnt[node] = l;
return ;
}
array<ll, 3> mn = qry(1, 0, timer+1, tin[node], tout[node]);
while(mn[1] != 0 && mn[0] < w[node] && cnt[node] > r) {
ll now = cnt[node]-mn[1];
if(now <= r) {
mn[1] -= (cnt[node]-r);
ans += (cnt[node]-r)*mn[0];
if(mn[1] == 0) mn[0] = INF;
upd(1, 0, timer+1, mn[2], mn);
cnt[node] = r;
}
else {
cnt[node] -= mn[1];
ans += mn[1]*mn[0];
upd(1, 0, timer+1, mn[2], {INF, 0, mn[2]});
}
mn = qry(1, 0, timer+1, tin[node], tout[node]);
}
if(cnt[node] > r) {
ans += (cnt[node]-r)*w[node];
cnt[node] = r;
}
upd(1, 0, timer+1, tin[node], {w[node], r-l, tin[node]});
}
/// Assigns Euler-tour positions to the subtree rooted at `node`, entered from `p`.
/// tin[v] is v's pre-order position and tout[v] is the largest position inside v's
/// subtree, so the subtree occupies exactly [tin[v], tout[v]].
///
/// Uses an explicit stack instead of recursion, so path-like trees do not need one
/// call frame per vertex. The visit order matches the recursive formulation.
void dfs1(int node, int p) {
    // Each frame: {vertex, its parent, index of the next child to examine}.
    vector<array<int, 3>> stk;
    tin[node] = ++timer;
    stk.push_back({node, p, 0});
    while (!stk.empty()) {
        const int u = stk.back()[0];
        const int par = stk.back()[1];
        const int idx = stk.back()[2];
        if (idx < static_cast<int>(adj[u].size())) {
            stk.back()[2] = idx + 1;
            const int v = adj[u][idx];
            if (v == par) continue;
            tin[v] = ++timer;
            stk.push_back({v, u, 0});
        } else {
            // All children done: the subtree ends at the last position handed out.
            tout[u] = timer;
            stk.pop_back();
        }
    }
}
void init(vector<int> P, vector<int> W) {
n = W.size();
w[0] = W[0];
if(P[0] != -1) while(1) {};
for(int i = 1; i < n; i++) {
adj[P[i]].push_back(i);
w[i] = W[i];
}
dfs1(0, -1);
}
/// Grader entry point: minimum total cost when every subtree sum must lie in [l, r].
/// Each call starts from a clean slate (zero subtree sums, an empty segment tree,
/// zero cost) and then runs the greedy post-order pass from the root.
ll query(int l, int r) {
    fill(cnt, cnt + n, 0LL);
    build(1, 0, timer + 1);
    ans = 0;
    const ll lo = l;
    const ll hi = r;
    dfs(0, -1, lo, hi);
    return ans;
}
/*int main()
{
init({-1, 0, 0}, {1, 1, 1});
for(int i = 0; i < 3; i++) {
cout << tin[i] << ' ' << tout[i] << '\n';
}
cout << query(1, 1) << ' ' << query(1, 2);
}*/
/*
 * Judge status table pasted together with this submission. The results were still
 * loading when it was copied, so no verdicts, times, or memory figures were captured.
 *
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 */