#include <bits/stdc++.h>
#include "tree.h"
using namespace std;
#ifdef DEBUG
#include "debug.h"
#else
#define debug(...) void(37)
#endif
// Disjoint-set union over vertices 0..n-1 that also keeps, per component, an
// integer counter seeded from the constructor argument (one entry per vertex).
struct DSU {
  std::vector<int> link;    // parent pointer; link[v] == v marks a root
  std::vector<int> leaves;  // per-component counter, meaningful at roots only

  DSU(std::vector<int> initial_leaves) : leaves(std::move(initial_leaves)) {
    link.resize(leaves.size());
    std::iota(link.begin(), link.end(), 0);
  }

  // Returns the root of v's component and points every vertex on the walked
  // path directly at that root (full path compression).
  // Iterative on purpose: unite() links without union-by-rank, so parent
  // chains can be as long as the number of vertices, and a recursive find
  // could overflow the call stack on such a chain.
  int get(int v) {
    int root = v;
    while (link[root] != root) root = link[root];
    while (link[v] != root) {
      const int next = link[v];
      link[v] = root;
      v = next;
    }
    return root;
  }

  // Attaches y's component under x's root (direction matters). The merged
  // counter is leaves[x] + leaves[y] - 1: one unit of x's count is replaced
  // by y's whole count. Returns false, changing nothing, if x and y are
  // already in the same component.
  bool unite(int x, int y) {
    x = get(x), y = get(y);
    if (x == y) return false;
    link[y] = x;
    leaves[x] += leaves[y] - 1;
    return true;
  }

  // Counter of the component containing v.
  int leaf_count(int v) {
    return leaves[get(v)];
  }
};
// A linear form L_coeff * L + R_coeff * R in the two query parameters.
// Both coefficients default to zero, so a default-constructed value is the
// zero form. The type remains an aggregate (C++14+), so brace initialization
// like linear_sum{a, b} still works.
struct linear_sum {
  int64_t L_coeff = 0, R_coeff = 0;
  // Evaluates the form at (L, R); the products are computed in 64 bits.
  int64_t eval(int L, int R) const {
    return L_coeff * L + R_coeff * R;
  }
  // Resets both coefficients to zero and returns a copy of the reset value.
  linear_sum init() {
    L_coeff = 0, R_coeff = 0;
    return *this;
  }
};
// Coefficient-wise sum of two linear forms.
linear_sum operator+(linear_sum l, linear_sum r) {
  const int64_t sum_L = l.L_coeff + r.L_coeff;
  const int64_t sum_R = l.R_coeff + r.R_coeff;
  return linear_sum{sum_L, sum_R};
}
// Coefficient-wise difference l - r of two linear forms.
linear_sum operator-(linear_sum l, linear_sum r) {
  return linear_sum{l.L_coeff - r.L_coeff, l.R_coeff - r.R_coeff};
}
// Unary negation: flips the sign of both coefficients (same as 0 - x).
linear_sum operator-(const linear_sum& x) {
  return linear_sum{-x.L_coeff, -x.R_coeff};
}
// State built by init() and read by query().
// Sum of W over childless vertices; query() adds leaf_sums * L to its answer.
int64_t leaf_sums = 0;
// placer[i] is the threshold of the i-th event after sorting (ascending).
vector<int> placer;
// pref_sums[i] is the sum of the linear forms of the first i sorted events.
vector<linear_sum> pref_sums;
// Preprocesses the tree for query().
//
// Vertices 1..N-1 have parent P[i]; P[v] == -1 marks a vertex with no parent
// (checked below). W[i] is the weight of vertex i. The function:
//   * counts children per vertex; childless vertices get count 1 and their
//     weight is accumulated into leaf_sums;
//   * activates vertices in order of decreasing weight. Each newly active
//     vertex absorbs its already-active children and, if its parent is
//     already active, is merged into the parent's component. A DSU tracks
//     the merged counts;
//   * for each vertex, emits events (threshold t, linear form in L and R
//     scaled by W[v]). query(L, R) later sums every event with t * L < R;
//   * sorts the events by threshold and stores their prefix sums.
// All global state (leaf_sums, placer, pref_sums) is rebuilt from scratch,
// so calling init() more than once does not accumulate stale values.
void init(std::vector<int> P, std::vector<int> W) {
  int N = int(P.size());
  // leaf_sums is only ever added to below, so it must start from zero.
  leaf_sums = 0;
  vector<int> degree(N);
  for (int i = 1; i < N; ++i) degree[P[i]]++;
  for (int i = 0; i < N; ++i) {
    if (degree[i] == 0) {
      degree[i] = 1;
      leaf_sums += W[i];
    }
  }
  DSU dsu(degree);
  vector<int> node_ord(N); iota(node_ord.begin(), node_ord.end(), 0);
  sort(node_ord.begin(), node_ord.end(), [&](int x, int y) {
    return W[x] > W[y];
  });
  // waiting[p] holds children of p that were activated before p itself.
  vector<vector<int>> waiting(N);
  vector<bool> act(N);
  vector<pair<int, linear_sum>> events;
  for (auto v : node_ord) {
    act[v] = true;
    // Counts of the previously activated child components, captured before
    // they are merged into v.
    vector<int> sizes;
    for (auto u : waiting[v]) {
      sizes.push_back(dsu.leaf_count(u));
      dsu.unite(v, u);
    }
    // Count of v's component before (subtree) and after (root) it joins
    // the parent's component.
    int subtree = dsu.leaf_count(v);
    if (P[v] != -1) {
      if (act[P[v]]) dsu.unite(P[v], v);
      else waiting[P[v]].push_back(v);
    }
    int root = dsu.leaf_count(v);
    // Records an event at threshold t whose coefficients are scaled by W[v].
    auto Add = [&](int t, linear_sum x) {
      x.R_coeff *= W[v], x.L_coeff *= W[v];
      events.push_back({t, x});
    };
    debug(v, subtree, root, sizes);
    //target: L + (R - (r - s + 1) * L) when R > (r - s + 1) * L, else L
    auto target = linear_sum{1 - root + subtree - 1, +1};
    int to_bigger_target_cut = (root - subtree + 1);
    //R < subtree * L
    {
      Add(0, {-1, 0}); // R - L
      Add(to_bigger_target_cut, -target + linear_sum{+1, 0}); //R - target
      //sums - R
      int non_leaf_c = int(sizes.size());
      int non_leaf_leaves = accumulate(sizes.begin(), sizes.end(), 0);
      Add(0, {subtree - non_leaf_leaves, non_leaf_c});
      for (auto s : sizes) {
        Add(s, {s, -1});
      }
    }
    //R > subtree * L
    {
      //Add(subtree, {subtree, 0}); maintained from the top
      Add(root, target - linear_sum{subtree, 0});
    }
  }
  sort(events.begin(), events.end(), [](const auto& x, const auto& y) {
    return x.first < y.first;
  });
  int s = int(events.size());
  // pref_sums[i] = sum of the first i sorted events; pref_sums[0] is zero.
  pref_sums.assign(s + 1, linear_sum{});
  placer.assign(s, 0);
  for (int i = 0; i < s; ++i) {
    placer[i] = events[i].first;
    pref_sums[i + 1] = pref_sums[i] + events[i].second;
  }
}
long long query(int L, int R) {
int p = int(lower_bound(placer.begin(), placer.end(), array<int, 2>{L, R}, [&](int x, array<int, 2> lr) {
return int64_t(lr[0]) * x < lr[1];
}) - placer.begin());
return leaf_sums * L + pref_sums[p].eval(L, R);
}
/*
 * The lines below appear to be residue copied from an online judge's results
 * table (the results were still loading when copied). They are not C++; they
 * are kept verbatim inside this comment so the file compiles.
 *
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
*/