#include <bits/stdc++.h>
#include "tree.h"
using namespace std;
using ll = long long;

// Capacity of every node-indexed table below.
constexpr int mxN = 200010;
// Sentinel weight stored in an empty segment-tree slot.
constexpr int INF = 1000000010;

bool one;                   // set by init(): true when every weight equals 1
vector<int> adj[mxN];       // child lists: init() appends i to adj[P[i]]
int n;                      // number of vertices
int timer = -1;             // last Euler index handed out by dfs
int l;                      // lower bound of the query being answered
int r;                      // upper bound of the query being answered
int leaves = 0;             // number of leaves counted by dfs
int tin[mxN];               // Euler entry index of each vertex
int tout[mxN];              // largest entry index inside each subtree
int sz[mxN];                // number of leaves in each subtree
ll ans = 0;                 // running answer of the current query
ll w[mxN];                  // vertex weights
ll cnt[mxN];                // per-vertex running subtree total used by dfs1
ll first = 0;               // sum of sz over internal vertices
array<int, 2> st[4 * mxN];  // segment tree of {weight, amount} slots
// Euler-tour DFS from `node` (parent `p` is skipped).
// Assigns tin/tout, sets sz[node] to the number of leaves in the subtree,
// increments `leaves` for every leaf and adds sz of each internal vertex to `first`.
void dfs(int node, int p) {
    tin[node] = ++timer;
    int children = 0;
    for (int child : adj[node]) {
        if (child != p) {
            ++children;
            dfs(child, node);
            sz[node] += sz[child];
        }
    }
    if (children == 0) {
        // A leaf contributes exactly one leaf to its own subtree.
        ++leaves;
        ++sz[node];
    } else {
        first += sz[node];
    }
    tout[node] = timer;
}
// Resets the segment-tree node `node` covering positions [l, r]:
// every leaf becomes the empty slot {INF, 0} and every internal node caches
// the lexicographic minimum of its two children.
void build(int node, int l, int r) {
    if (l != r) {
        const int mid = (l + r) / 2;
        const int left = node * 2;
        const int right = left + 1;
        build(left, l, mid);
        build(right, mid + 1, r);
        st[node] = min(st[left], st[right]);
    } else {
        st[node] = {INF, 0};
    }
}
// Point assignment: stores `x` at leaf position k below `node` (which covers
// [l, r]) and refreshes the cached minima on the way back up.
// A position outside [l, r] leaves the tree untouched.
void upd(int node, int l, int r, int k, array<int, 2> x) {
    if (k < l || k > r) return;  // k is not in this segment
    if (l == r) {                // the leaf for k
        st[node] = x;
        return;
    }
    const int mid = (l + r) / 2;
    if (k <= mid) {
        upd(node * 2, l, mid, k, x);
    } else {
        upd(node * 2 + 1, mid + 1, r, k, x);
    }
    st[node] = min(st[node * 2], st[node * 2 + 1]);
}
array<int, 3> qry(int node, int l, int r, int l1, int r1) {
if(l1 <= l && r <= r1) {
array<int, 3> now = {st[node][0], st[node][1], l};
return now;
}
if(l1 > r || r1 < l) return {INF, 0, 0};
int mid = (l+r)/2;
return min(qry(node*2, l, mid, l1, r1), qry(node*2+1, mid+1, r, l1, r1));
}
void dfs1(int node, int p) {
bool leaf = true;
for(auto it : adj[node]) {
if(it == p) continue;
leaf = false;
dfs1(it, node);
cnt[node] += cnt[it];
}
if(leaf) {
ans += l*w[node];
cnt[node] = l;
return ;
}
while(cnt[node] > r) {
array<int, 3> now = qry(1, 0, timer+1, tin[node], tout[node]);
if(now[0] >= w[node]) {
ans += (cnt[node] - r) * w[node];
cnt[node] = r;
break;
}
int nxt = cnt[node] - now[1];
if(nxt <= r) {
nxt = (cnt[node]-r); cnt[node] -= nxt;
now[1] -= nxt; ans += now[0] * nxt;
upd(1, 0, timer+1, now[2], {now[0], now[1]});
}
else {
cnt[node] -= now[1]; ans += now[0] * now[1];
upd(1, 0, timer+1, now[2], {INF, 0});
}
}
upd(1, 0, timer+1, tin[node], {w[node], r-l});
}
// Loads the tree and the weights, then precomputes the Euler tour and leaf counts.
// P[i] is the parent of vertex i (P[0] is ignored) and W[i] is its weight.
// State left by a previous call (adjacency, sz, timer, counters) is cleared
// first, so calling init() again does not accumulate stale data.
void init(vector<int> P, vector<int> W) {
    const int prevN = n;
    n = static_cast<int>(P.size());
    for (int i = 0; i < max(prevN, n); i++) {
        adj[i].clear();
        sz[i] = 0;  // dfs accumulates into sz with +=
    }
    timer = -1;
    leaves = 0;
    first = 0;

    one = true;
    for (int i = 0; i < n; i++) {
        w[i] = W[i];
        if (w[i] != 1) one = false;
    }
    for (int i = 1; i < n; i++) {
        adj[P[i]].push_back(i);
    }
    dfs(0, -1);
}
// Answers one query with bounds [L, R] and returns the accumulated cost.
// Presumably this is the IOI 2024 "Tree" model: minimise sum |C[i]| * W[i]
// subject to every subtree sum lying in [L, R] (assumption, confirm).
ll query(int L, int R) {
    ans = 0; l = L; r = R;
    if (one) {
        // All weights are 1. Each of the k leaves must carry at least L, so the
        // leaves contribute >= k*L. The total over the whole tree is at most R,
        // so the internal vertices must remove at least k*L - R more.
        // Setting every leaf to L and clamping each subtree sum at R attains
        // that bound. Computed in 64-bit because k*L can exceed INT_MAX.
        const ll leafSum = static_cast<ll>(L) * leaves;
        ans = leafSum + max(0LL, leafSum - R);
    } else {
        build(1, 0, timer + 1);
        for (int i = 0; i < n; i++) cnt[i] = 0;
        dfs1(0, -1);
    }
    return ans;
}
/*int main()
{
init({-1, 0, 0}, {1, 1, 1});
//init({-1, 0, 0, 1, 1, 1, 2, 2, 2}, {2, 3, 1, 4, 4, 3, 3, 3, 1});
cout << query(1, 1) << ' ' << query(1, 2) << ' ' << query(3, 5);
cout << ' ' << query(5, 10);
}*/
/*
Judge output captured together with this submission (kept as a comment so the file compiles):

Compilation message (stderr)
tree.cpp: In function 'void dfs1(int, int)':
tree.cpp:94:42: warning: narrowing conversion of 'w[node]' from 'll' {aka 'long long int'} to 'int' [-Wnarrowing]
94 | upd(1, 0, timer+1, tin[node], {w[node], r-l});
| ~~~~~~^
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
*/