#include "tree.h"
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
using ii = pair<int, int>;
// Capacity for 1-indexed per-node arrays (node ids run 1..n).
constexpr int N = 200'005;
// Subtask flags computed by init(): SUB_4 = all weights are 1,
// SUB_5 = max weight is 1 but not all weights are 1.
bool SUB_4 = false;
bool SUB_5 = false;
// Number of vertices.
int n = 0;
// Parent array as passed to init() (0-indexed; dfs2 treats -1 as "no parent").
vector<int> p;
// Children lists, 1-indexed on both ends (node i of the input is stored as i + 1).
vector<int> v[N];
// Vertex weights as passed to init() (0-indexed).
vector<int> w;
// Stores the tree (P = parent array, W = weights, both 0-indexed), builds the
// 1-indexed children lists in v[], and flags the weight-restricted cases:
//   SUB_4 — every weight equals 1;
//   SUB_5 — the largest weight is 1 but the smallest is not.
void init(vector<int> P, vector<int> W) {
    p = std::move(P);
    w = std::move(W);
    n = static_cast<int>(p.size());
    for (int child = 1; child < n; ++child) {
        // Shift both endpoints to 1-based ids.
        const int parent = p[child] + 1;
        v[parent].push_back(child + 1);
    }
    const auto [lightest, heaviest] = minmax_element(w.begin(), w.end());
    if (*heaviest == 1) {
        if (*lightest == 1) {
            SUB_4 = true;
        } else {
            SUB_5 = true;
        }
    }
}
// Current query bounds, written by query() before any solver runs.
int L;
int R;
// State of the general greedy dfs(). Indexed by 1-based node id, except that
// sum[] and s[] are indexed by multiset index (see id[]).
ll coef[N];  // coefficient assigned to each node
ll sub[N];   // subtree sum of each node as computed by dfs()
ll id[N];    // index of the multiset currently owned by each node
ll sum[N];   // total capacity (second tuple field) held in s[k]
multiset<tuple<ll, ll, int>> s[N];  // (weight, capacity, node) entries
void dfs(int x) {
if (v[x].empty()) {
coef[x] = L;
sub[x] = L;
return;
}
sub[x] = 0;
coef[x] = 0;
sum[x] = 0;
for(auto u : v[x]) {
dfs(u);
sub[x] += sub[u];
if(s[id[u]].size() > s[id[x]].size()) { // assert sz[u] <= sz[x]
swap(id[u], id[x]);
}
for(auto& e : s[id[u]]) {
s[id[x]].insert(e);
sum[id[x]] += get<1>(e);
}
while((int) s[id[x]].size() and get<0>(*s[id[x]].rbegin()) >= w[x - 1]) {
sum[id[x]] -= get<1>(*s[id[x]].rbegin());
s[id[x]].erase(prev(s[id[x]].end()));
}
}
ll can = sub[x] - L;
if(can > sum[id[x]]) {
s[id[x]].insert({w[x - 1], can - sum[id[x]], x});
sum[id[x]] = can;
}
auto add = [&](int u, int v, ll c) {
while(u != v) {
sub[u] -= c;
u = p[u - 1] + 1;
}
sub[u] -= c;
};
auto get = [&](int u, int v) {
ll ret = 1e18;
// while(u != v) {
// ret = min(ret, sub[u] - L);
// u = p[u - 1] + 1;
// }
ret = min(ret, sub[u] - L);
return ret;
};
while(sub[x] > R) {
assert(!s[id[x]].empty());
auto [w, c, u] = *s[id[x]].begin();
// c = get(u, x);
s[id[x]].erase(s[id[x]].begin());
ll need = sub[x] - R;
sum[id[x]] -= min(c, need);
if(c <= need) {
// add(u, x, c);
coef[u] -= c;
sub[x] -= c;
// coef[u] -= c;
}
else {
// add(u, x, need);
coef[u] -= need;
sub[x] -= need;
// coef[u] -= need;
s[id[x]].insert({w, c - need, u});
}
}
}
// Lazy-initialisation state for solve4() (all weights equal to 1).
bool INIT_4{false};
int leaf{};  // number of leaves, counted by init4()
// One-time preprocessing for solve4(): marks itself done and adds the number
// of childless nodes among 1..n to `leaf`.
void init4() {
    INIT_4 = true;
    const auto childless = count_if(v + 1, v + n + 1,
                                    [](const vector<int>& kids) { return kids.empty(); });
    leaf += static_cast<int>(childless);
}
// Closed form when every weight is 1: each leaf costs L, plus whatever the
// total leaf sum exceeds R by.
ll solve4() {
    if (!INIT_4) {
        init4();
    }
    const ll total = static_cast<ll>(leaf) * L;
    return total + max(0LL, total - R);
}
// Lazily built data for solve5() (every weight is 0 or 1).
bool INIT_5 = false;
int mx;                 // largest count recorded by dfs2 (not read by the live code)
int cnt[N];             // cnt[k] = number of weighted component roots whose dfs2 value is k
ll suf_cnt[N];          // suffix sums of cnt[], built by init5()
ll suf_cnt_times_i[N];  // suffix sums of k * cnt[k], built by init5()
// Returns how many "effective leaves" x contributes upward:
//   * 1 for a leaf (of any weight) or for any zero-weight node;
//   * otherwise the sum of its children's values.
// A weighted internal node whose parent is absent (-1) or has weight 0 roots a
// weighted component; its value k is tallied in cnt[k] and folded into mx.
// NOTE(review): recursive; stack depth equals the tree height.
int dfs2(int x) {
    if (v[x].empty()) {
        return 1;
    }
    int leaves = 0;
    for (int child : v[x]) {
        leaves += dfs2(child);
    }
    if (w[x - 1] == 0) {
        return 1;
    }
    const int parent = p[x - 1];
    const bool component_root = (parent == -1) || (w[parent] == 0);
    if (component_root) {
        ++cnt[leaves];
        mx = max(mx, leaves);
    }
    return leaves;
}
// Counters filled by init5(): leaf1 = weighted leaves, leaf2 = dfs2 value of the root.
int leaf1 = 0;
int leaf2 = 0;
void init5() {
INIT_5 = true;
leaf2 = dfs2(1);
for(int i = 1; i <= n; i++) {
if(v[i].empty() and w[i - 1]) {
leaf1++;
}
}
for(int i = n; i >= 1; i--) {
suf_cnt[i] = suf_cnt[i + 1] + cnt[i];
suf_cnt_times_i[i] = suf_cnt_times_i[i + 1] + (ll) cnt[i] * i;
}
}
// Closed-form answer when every weight is 0 or 1 (used for SUB_4 and SUB_5).
// Each weighted leaf costs L. Each weighted component with dfs2 value k
// (tallied in cnt[k]) additionally costs k * L - R when k * L > R, i.e.
//   res = leaf1 * L + sum_{k : k*L > R} cnt[k] * (k * L - R).
// The sum is read off the suffix arrays starting at the smallest such k.
ll solve5() {
    if (!INIT_5) {
        init5();
    }
    ll res = static_cast<ll>(leaf1) * L;

    // Lower bound: smallest k in [1, n] with k * L > R, or n + 1 if none.
    // The predicate is false-then-true in k, so shrink `hi` on true.
    int lo = 1;
    int hi = n + 1;
    while (lo < hi) {
        const int mid = lo + (hi - lo) / 2;
        if (static_cast<ll>(mid) * L > R) {
            hi = mid;
        } else {
            lo = mid + 1;
        }
    }
    if (lo <= n) {
        res += suf_cnt_times_i[lo] * L - suf_cnt[lo] * R;
    }
    return res;
}
// Answers one query with bounds [LL, RR].
// When all weights are 0/1 (SUB_4 or SUB_5) the closed form solve5() is used.
// Otherwise the greedy dfs() is rerun from scratch, and the result is
// sum over nodes of |coef[node]| * weight.
ll query(int LL, int RR) {
    L = LL;
    R = RR;
    if (SUB_4 || SUB_5) {
        return solve5();
    }

    // Every node starts out owning its own empty multiset with zero capacity.
    for (int node = 1; node <= n; ++node) {
        s[node].clear();
        sum[node] = 0;
        id[node] = node;
    }
    dfs(1);

    ll total = 0;
    for (int node = 1; node <= n; ++node) {
        const ll magnitude = abs(coef[node]);
        total += magnitude * w[node - 1];
    }
    return total;
}
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |