#include "tree.h"
#include<bits/stdc++.h>
using namespace std;
#define F first
#define S second
#define pll pair<ll, ll>
#define vll vector<ll>
#define pb push_back
typedef long long ll;
namespace{
// Capacity of the per-vertex arrays (p, w, adj, in, out); vertices are 0..n-1.
constexpr ll mxN = 200005;
// "No value" sentinel used by the min tree for empty / retired slots.
constexpr ll inf = 1000000000000000000LL;
// Capacity of the prel/prer tables; query() indexes them with R / L.
constexpr ll mxM = 1000005;
struct segtreeadd{
vll tree;
ll siz;
void init(ll sz){
siz=sz+1;
tree=vll(siz, 0);
}
void update(ll idx, ll val){
idx++;
for(;idx<siz;idx+=(idx&(-idx))){
tree[idx]+=val;
}
}
ll getpre(ll idx){
idx++;
ll re=0;
for(;idx>0;idx-=(idx&(-idx))){
re+=tree[idx];
}
return re;
}
ll getsum(ll qlow, ll qhigh){
return getpre(qhigh)-getpre(qlow-1);
}
};
struct segtreemin{
vector<pll> tree;
ll treelen;
void init(ll sz){
treelen=sz+1;
while(__builtin_popcount(treelen)!=1) treelen++;
tree=vector<pll>(2*treelen, {inf, inf});
}
void update(ll idx, pll val){
ll tar=idx+treelen;
tree[tar]=val;
tar/=2;
while(tar>0){
tree[tar]=min(tree[2*tar], tree[2*tar+1]);
tar/=2;
}
}
pll getmin1(ll idx, ll low, ll high, ll qlow, ll qhigh){
if(low>=qlow && high<=qhigh){
return tree[idx];
}
if(low>qhigh || high<qlow){
return {inf, inf};
}
ll mid=(low+high)/2;
return min(getmin1(2*idx, low, mid, qlow, qhigh),
getmin1(2*idx+1, mid+1, high, qlow, qhigh));
}
pll getmin(ll qlow, ll qhigh){
return getmin1(1, 0, treelen-1, qlow, qhigh);
}
};
// Fenwick tree over preorder positions: init() adds 1 at every leaf, and
// chainupd() later adjusts it while f() collapses processed subtrees.
segtreeadd seg1;
// Min tree over preorder positions: {w[v], v} for internal vertices that f()
// has not processed yet; leaves and processed vertices hold {inf, inf}.
segtreemin seg2;
ll n;           // number of vertices
ll l;           // L of the most recent query()
ll r;           // R of the most recent query()
ll p[mxN];      // parent of each vertex, copied from P
ll w[mxN];      // weight of each vertex, copied from W
vll adj[mxN];   // children of each vertex, in increasing index order
ll in[mxN];     // preorder position of each vertex (set by dfs)
ll out[mxN];    // last preorder position inside the vertex's subtree (set by dfs)
ll timer;       // next preorder position handed out by dfs
ll ans;         // result of the most recent query()
// Range contributions {end, begin, coefL, coefR, vertex}: init() adds coefL to
// prel and coefR to prer over indices [begin, end).
vector<array<ll, 5>> eq;
ll leafsum;     // total weight of all leaves
// After init(): for a query with R / L == k, prel[k] and prer[k] are the
// coefficients of L and R in the answer beyond the leafsum * L term.
ll prel[mxM];
ll prer[mxM];
// Number cur's subtree in preorder, visiting children in adj[] order.
// in[v] is v's position and out[v] is the last position used inside v's
// subtree, so that subtree occupies exactly [in[v], out[v]]. The global
// `timer` holds the next free position and ends just past the subtree.
// An explicit stack replaces recursion, so a path-shaped tree cannot exhaust
// the call stack.
void dfs(ll cur){
    // Frame = (vertex, index into adj[vertex] of the next child to descend into).
    std::vector<std::pair<ll, ll>> stk;
    auto enter = [&](ll v) {
        in[v] = timer++;
        stk.emplace_back(v, 0);
    };
    enter(cur);
    while (!stk.empty()) {
        auto &[v, nxt] = stk.back();
        if (nxt < static_cast<ll>(adj[v].size())) {
            const ll child = adj[v][nxt++];
            enter(child);  // may reallocate stk; v/nxt are not used afterwards
        } else {
            out[v] = timer - 1;  // every position in v's subtree is now assigned
            stk.pop_back();
        }
    }
}
// Point-add val at b and -val at t's parent (skipped when p[t] == -1).
// seg1.getsum over a subtree [in[x], out[x]] sees the +val when b is in x's
// subtree and the -val when p[t] is. So when t is an ancestor of (or equal to)
// b, as in f(), exactly the vertices on the path b..t gain val in their
// subtree sums.
void chainupd(ll t, ll b, ll val){
    seg1.update(in[b], val);
    const ll above = p[t];
    if (above == -1) {
        return;
    }
    seg1.update(in[above], -val);
}
// Repeatedly pick the smallest remaining {w[v], v} among unprocessed internal
// vertices in cur's preorder range [in[cur], out[cur]], as served by seg2.
// For each picked vertex `node`:
//   - record its range contributions in eq,
//   - collapse its seg1 subtree total to 1 along the path node..cur,
//   - retire it from seg2,
//   - recurse into each of its children,
// and stop once seg2 reports nothing left under cur.
// The original trailing self-call f(cur) is expressed as the loop.
void f(ll cur){
    while (true) {
        pll pick = seg2.getmin(in[cur], out[cur]);
        if (pick.F == inf) {
            return;  // no unprocessed internal vertex left under cur
        }
        ll node = pick.S;
        ll wt = w[node];

        // seg1 subtree sums: leaf counts from init(), reduced by earlier collapses.
        ll outer = seg1.getsum(in[cur], out[cur]);
        ll inner = seg1.getsum(in[node], out[node]);

        // eq entries are {end, begin, coefL, coefR, vertex}; init() adds coefL to
        // prel and coefR to prer over indices [begin, end).
        eq.pb({outer - inner + 1, 0, -wt, 0, node});
        eq.pb({outer, outer - inner + 1, (outer - inner) * wt, -wt, node});
        for (ll child : adj[node]) {
            ll cnt = seg1.getsum(in[child], out[child]);
            eq.pb({outer, cnt, cnt * wt, 0, node});
            eq.pb({cnt, 0, 0, wt, node});
        }

        // node's subtree now counts as 1 on the path node..cur; never pick it again.
        chainupd(cur, node, -(inner - 1));
        seg2.update(in[node], {inf, inf});

        for (ll child : adj[node]) {
            f(child);
        }
    }
}
}
// Grader entry point. It:
//   - copies the parent array P and weights W,
//   - builds the child lists and preorder numbering,
//   - runs f(0) to collect range contributions in eq,
//   - folds eq into the prel/prer tables that query() reads.
// NOTE(review): globals (adj, eq, prel, prer) are not reset, so this assumes a
// single call per run — confirm against the grader.
void init(vector<int> P, vector<int> W) {
    n = P.size();
    for (ll v = 0; v < n; ++v) {
        p[v] = P[v];
        w[v] = W[v];
    }
    // Child lists; vertex 0 is the root and receives no parent edge.
    for (ll v = 1; v < n; ++v) {
        adj[p[v]].pb(v);
    }

    timer = 0;
    dfs(0);

    // Leaves start with count 1 in seg1 and are never candidates in seg2;
    // internal vertices become seg2 candidates keyed by (weight, index).
    leafsum = 0;
    seg1.init(n);
    seg2.init(n);
    for (ll v = 0; v < n; ++v) {
        const bool isLeaf = (in[v] == out[v]);
        if (!isLeaf) {
            seg2.update(in[v], {w[v], v});
            continue;
        }
        seg1.update(in[v], 1);
        seg2.update(in[v], {inf, inf});
        leafsum += w[v];
    }

    f(0);

    // Difference-array additions over [lo, hi), then prefix sums over the tables.
    for (const auto &[hi, lo, cl, cr, who] : eq) {
        prel[lo] += cl;
        prel[hi] -= cl;
        prer[lo] += cr;
        prer[hi] -= cr;
    }
    std::partial_sum(prel, prel + mxM, prel);
    std::partial_sum(prer, prer + mxM, prer);
}
long long query(int L, int R) {
ans=0;
l=L;
r=R;
ans+=leafsum*l;
ll divi=r/l;
ans+=prel[divi]*l+prer[divi]*r;
return ans;
}
/*
Judge status table pasted from the submission page; it is not part of the program.
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
*/