#include <bits/stdc++.h>
#include "tree.h"
using namespace std;
using ll = long long;
// Capacity of the per-node arrays below; node ids index directly into them.
const int mxN = 2e5+10;
// adj[v] lists the children of v: init() appends i to adj[P[i]] for every i >= 1.
vector<int> adj[mxN];
// w[v] is the weight of node v, copied from W in init().
ll w[mxN];
// tin[v]  : pre-order index given to v by dfs(); indices start at 0 because timer starts
//           at -1 and is pre-incremented.
// tout[v] : largest pre-order index handed out inside v's subtree.
// n       : number of nodes (P.size(), set in init()).
// leaves  : number of nodes whose adj list is empty, counted by dfs().
int tin[mxN], tout[mxN], timer = -1, n, leaves = 0;
// True when every weight equals 1; query() then answers with a closed-form expression.
bool one;
// Pre-order walk starting at `node`, never stepping back into `p`.
// On entry the node receives the next timer value as tin[node]; on exit tout[node] is set
// to the last timer value handed out inside its subtree. Every node with an empty adj list
// increments the global `leaves` counter.
void dfs(int node, int p) {
    timer += 1;
    tin[node] = timer;
    if (adj[node].empty()) {
        leaves += 1;
    }
    for (size_t k = 0; k < adj[node].size(); ++k) {
        const int next = adj[node][k];
        if (next != p) {
            dfs(next, node);
        }
    }
    tout[node] = timer;
}
// Segment tree over positions 1..n with range add and range sum.
//   tree[i] : sum of node i's segment, including every add applied at or above i.
//   lz[i]   : per-position add that has been applied to tree[i] but not yet to i's children.
// A covering add is scaled by the segment length, so range sums stay correct after range
// adds (not only after point adds).
struct LazySeg {
    int n;
    vector<ll> tree, lz;

    // Builds an all-zero tree over positions 1..n.
    LazySeg(int n) {
        this->n = n;
        init();
    }

    // Resets every stored sum and pending add to zero. assign() keeps the existing
    // allocation when the size is unchanged, so repeated resets do not reallocate.
    void init() {
        tree.assign(static_cast<size_t>(n) * 4, 0);
        lz.assign(static_cast<size_t>(n) * 4, 0);
    }

    // Adds `delta` to every position of node i, whose segment is [l, r].
    void apply(int i, int l, int r, ll delta) {
        tree[i] += delta * (r - l + 1);
        lz[i] += delta;
    }

    // Moves node i's pending add down to its two children; [l, r] is i's segment.
    // Leaves have no children, so for them the pending value is simply dropped.
    void push(int i, int l, int r) {
        if (l < r && lz[i] != 0) {
            int m = l + (r - l) / 2;
            apply(i * 2, l, m, lz[i]);
            apply(i * 2 + 1, m + 1, r, lz[i]);
        }
        lz[i] = 0;
    }

    // Compatibility overload: recovers node i's segment by replaying the root-to-node
    // path encoded in i's bits (0 = left child, 1 = right child), then pushes.
    void push(int i) {
        int top = 0;
        while ((i >> (top + 1)) != 0) ++top;
        int l = 1, r = n;
        for (int b = top - 1; b >= 0; --b) {
            int m = l + (r - l) / 2;
            if ((i >> b) & 1) l = m + 1;
            else r = m;
        }
        push(i, l, r);
    }

    // Adds `delta` to every position in [ql, qr] that lies inside node i's segment [l, r].
    void update(int i, int l, int r, int ql, int qr, ll delta) {
        if (l > qr || r < ql) return;
        if (ql <= l && r <= qr) {
            apply(i, l, r, delta);
            return;
        }
        push(i, l, r);
        int m = l + (r - l) / 2;
        if (ql <= m) update(i * 2, l, m, ql, qr, delta);
        if (qr > m) update(i * 2 + 1, m + 1, r, ql, qr, delta);
        tree[i] = tree[i * 2] + tree[i * 2 + 1];
    }

    // Adds `delta` to every position in [ql, qr].
    void update(int ql, int qr, ll delta) {
        update(1, 1, n, ql, qr, delta);
    }

    // Returns the sum over the part of [ql, qr] that lies inside node i's segment [l, r].
    ll query(int i, int l, int r, int ql, int qr) {
        if (l > qr || r < ql) return 0;
        if (ql <= l && r <= qr) return tree[i];
        push(i, l, r);
        int m = l + (r - l) / 2;
        return query(i * 2, l, m, ql, qr) + query(i * 2 + 1, m + 1, r, ql, qr);
    }

    // Returns the sum over positions [ql, qr].
    ll query(int ql, int qr) {
        return query(1, 1, n, ql, qr);
    }
} sums(2 * mxN), del(2 * mxN);  // shared instances; query() resets both before using them
// Loads the tree: P[i] is the parent of node i for i >= 1 (P[0] is never read) and W[i] is
// the weight of node i. Stores the node count, the weights, the children lists and the
// all-weights-are-1 flag, then numbers the nodes with dfs() from node 0.
// NOTE(review): presumably called once by the grader declared in tree.h; adj, timer and
// leaves are not reset here, so a second call would accumulate state. Confirm.
void init(vector<int> P, vector<int> W) {
    n = P.size();
    copy(W.begin(), W.begin() + n, w);
    one = all_of(w, w + n, [](ll weight) { return weight == 1; });
    for (int child = 1; child < n; ++child) {
        const int parent = P[child];
        adj[parent].push_back(child);
    }
    dfs(0, -1);
}
// Answers one (L, R) query on the tree loaded by init().
// NOTE(review): presumably this is the minimum total cost of the IOI "tree" task, where
// every subtree sum must end up in [L, R]; the comments below describe only what the code
// does.
//
// All weights equal to 1: returns leaves*L + max(0, leaves*L - R).
//
// Otherwise a post-order greedy runs:
//   * every leaf gets value L and contributes L * w[leaf] to the answer;
//   * at an internal node, while the subtree sum exceeds R, the node with the smallest
//     (weight, id) still available in that subtree is lowered as far as possible, at a
//     cost of w[u] per unit, without dropping its own subtree sum below L;
//   * once a node's subtree sum reaches L, its whole subtree is marked in `del` and its
//     stale entries are dropped from the candidate sets when they surface.
// `sums` holds each node's own value at its pre-order position, so a range sum over a
// node's pre-order interval is its subtree sum.
ll query(int L, int R) {
    if (one) {
        const ll base = static_cast<ll>(leaves) * L;
        return base + max(0LL, base - R);
    }

    sums.init();
    del.init();
    ll ans = 0;

    // dfs() numbers nodes from 0 (timer starts at -1) while LazySeg covers positions 1..n.
    // Without this shift the root's position 0 falls outside the tree and its point update
    // is silently ignored.
    auto first = [](int v) { return tin[v] + 1; };
    auto last = [](int v) { return tout[v] + 1; };

    // mn[v]: candidate (weight, id) pairs of v's subtree, merged small-to-large.
    vector<set<array<int, 2>>> mn(n + 1);

    // Named `solve` rather than `dfs` so it no longer shadows the file-level dfs(); the
    // explicit return type keeps the recursive call independent of deduction order.
    auto solve = [&](auto& self, int node) -> void {
        if (adj[node].empty()) {
            ans += 1LL * L * w[node];
            sums.update(first(node), first(node), L);
            return;
        }

        // Explicit cast: list-initialising array<int, 2> from an ll is a narrowing
        // conversion. The cast keeps the value the original already stored.
        mn[node].insert({static_cast<int>(w[node]), node});
        for (int child : adj[node]) {
            self(self, child);
            // Always pour the smaller set into the larger one.
            if (mn[child].size() > mn[node].size()) swap(mn[child], mn[node]);
            mn[node].insert(mn[child].begin(), mn[child].end());
            mn[child].clear();
        }

        ll cnt = sums.query(first(node), last(node));
        // `node` itself stays in the set while its sum exceeds R >= L, so the set cannot
        // run empty inside this loop.
        while (cnt > R) {
            const int u = (*mn[node].begin())[1];
            if (del.query(first(u), first(u)) > 0) {
                // u lies inside a subtree that was already exhausted: lazy deletion.
                mn[node].erase(mn[node].begin());
                continue;
            }
            // How far u's subtree sum can still drop before hitting L.
            const ll room = sums.query(first(u), last(u)) - L;
            const ll need = min(cnt - R, room);
            cnt -= need;
            sums.update(first(u), first(u), -need);
            ans += need * w[u];
            if (need == room) {
                // u's subtree sum is now exactly L: nothing below u may shrink further.
                mn[node].erase(mn[node].begin());
                del.update(first(u), last(u), 1);
            }
        }
    };

    solve(solve, 0);
    return ans;
}
// Compilation message (stderr):
//   tree.cpp: In instantiation of 'query(int, int)::<lambda(auto:47&, int)> [with auto:47 = query(int, int)::<lambda(auto:47&, int)>]':
//   tree.cpp:123:8:   required from here
//   tree.cpp:98:32: warning: narrowing conversion of 'w[node]' from 'll' {aka 'long long int'} to 'int' [-Wnarrowing]
//      98 |     mn[node].insert({w[node], node});
//         |                      ~~~~~~^
// Submission result tables (captured before the judge had finished; no verdicts recorded):
// # | Verdict | Execution time | Memory | Grader output |
// ---|---|---|---|---|
// Fetching results... |
// # | Verdict | Execution time | Memory | Grader output |
// ---|---|---|---|---|
// Fetching results... |
// # | Verdict | Execution time | Memory | Grader output |
// ---|---|---|---|---|
// Fetching results... |
// # | Verdict | Execution time | Memory | Grader output |
// ---|---|---|---|---|
// Fetching results... |
// # | Verdict | Execution time | Memory | Grader output |
// ---|---|---|---|---|
// Fetching results... |
// # | Verdict | Execution time | Memory | Grader output |
// ---|---|---|---|---|
// Fetching results... |
// # | Verdict | Execution time | Memory | Grader output |
// ---|---|---|---|---|
// Fetching results... |
// # | Verdict | Execution time | Memory | Grader output |
// ---|---|---|---|---|
// Fetching results... |