#include <bits/stdc++.h>
using namespace std;
// Shorthand types shared by the rest of the file.
using ll = long long;              // 64-bit accumulator type
using pll = std::pair<ll, ll>;     // pair of 64-bit values
using arr3 = std::array<ll, 3>;    // fixed triple of 64-bit values

// Member-name shorthands; these must stay macros because they stand in for
// member names (v.pb(...), p.ff, p.ss), which an alias cannot do.
#define pb push_back
#define ff first
#define ss second

// Inclusive upper end of the position range [1, N] indexed by DS.
constexpr int N = 1'000'000'005;
// Dynamic segment tree over integer positions [1, kMaxPos].
// Each occupied leaf p holds a pair of values (x, y). Every internal node
// caches the sums of x and y over its subtree, plus `s`, the number of occupied
// leaves below it. Nodes are allocated on demand with `new`.
//
// Copying a DS copies only the root pointer, so copies share nodes. std::swap
// (used for small-to-large merging) is fine because it just exchanges roots.
// Nodes still reachable from `rt` are never released (there is no destructor).
struct DS{
    // Inclusive upper end of the position range (same value as the file's N).
    static constexpr int kMaxPos = 1000000005;

    struct node{
        node *l = nullptr, *r = nullptr;  // children, created lazily
        long long x = 0, y = 0;           // subtree sums of the two components
        int s = 0;                        // number of occupied leaves below
    };

    std::vector<std::array<long long, 3>> vv;  // filled by fn(): {position, x, y}
    node *rt = nullptr;                        // root; set by init()

    // Creates an empty root. Call once before using any other member.
    void init(){
        rt = new node();
    }

    // Appends the occupied leaves under v (which covers [tl, tr]) to vv,
    // in increasing position order.
    void fn(node *v, int tl, int tr){
        if (v->s == 0) return;
        if (tl == tr){
            vv.push_back({tl, v->x, v->y});
            return;
        }
        const int mid = tl + (tr - tl) / 2;  // equals (tl + tr) / 2 without overflow risk
        if (v->l) fn(v->l, tl, mid);
        if (v->r) fn(v->r, mid + 1, tr);
    }

    // Replaces the contents of vv with every occupied leaf of the tree.
    void fn(){
        vv.clear();
        fn(rt, 1, kMaxPos);
    }

    // Number of occupied positions.
    int size() const{
        return rt->s;
    }

    // Recomputes v's cached sums from whichever children exist.
    void recalc(node *v){
        v->x = v->y = 0;
        v->s = 0;
        for (node *child : {v->l, v->r}){
            if (!child) continue;
            v->x += child->x;
            v->y += child->y;
            v->s += child->s;
        }
    }

    // Releases every node of the subtree rooted at v (v may be null).
    void destroy(node *v){
        if (!v) return;
        destroy(v->l);
        destroy(v->r);
        delete v;
    }

    // Adds val to leaf p inside the subtree v covering [tl, tr]. Missing nodes
    // are created on the way down, and sums are refreshed on the way up.
    void add(node *v, int tl, int tr, int& p, std::pair<long long, long long>& val){
        if (tl == tr){
            v->x += val.first;
            v->y += val.second;
            v->s = 1;
            return;
        }
        const int mid = tl + (tr - tl) / 2;
        if (p <= mid){
            if (!v->l) v->l = new node();
            add(v->l, tl, mid, p, val);
        }
        else {
            if (!v->r) v->r = new node();
            add(v->r, mid + 1, tr, p, val);
        }
        recalc(v);
    }

    // Adds x = (dx, dy) to position p (1 <= p <= kMaxPos).
    void add(int p, std::pair<long long, long long> x){
        add(rt, 1, kMaxPos, p, x);
    }

    // Removes every position <= p from the subtree v covering [tl, tr].
    // Fully covered subtrees are freed, not just unlinked, so repeated clears
    // do not leak memory.
    void clear(node *&v, int tl, int tr, int& p){
        if (tl > p) return;
        if (tr <= p){
            destroy(v);
            v = nullptr;
            return;
        }
        const int mid = tl + (tr - tl) / 2;
        if (v->l) clear(v->l, tl, mid, p);
        if (v->r) clear(v->r, mid + 1, tr, p);
        recalc(v);
    }

    // Removes every position <= x. If that empties the whole range, a fresh
    // root is created so that size()/get()/add() remain valid afterwards.
    void clear(int x){
        clear(rt, 1, kMaxPos, x);
        if (!rt) rt = new node();
    }

    // Sums (x, y) over the positions in [l, r] that lie inside v's range [tl, tr].
    std::pair<long long, long long> get(node *v, int tl, int tr, int& l, int& r){
        if (l > tr || r < tl) return {0, 0};
        if (l <= tl && tr <= r) return {v->x, v->y};
        const int mid = tl + (tr - tl) / 2;
        std::pair<long long, long long> ret{0, 0};
        if (v->l){
            const auto part = get(v->l, tl, mid, l, r);
            ret.first += part.first;
            ret.second += part.second;
        }
        if (v->r){
            const auto part = get(v->r, mid + 1, tr, l, r);
            ret.first += part.first;
            ret.second += part.second;
        }
        return ret;
    }

    // Sums (x, y) over positions in [l, r].
    std::pair<long long, long long> get(int l, int r){
        return get(rt, 1, kMaxPos, l, r);
    }
};
int main(){
ios_base::sync_with_stdio(0);
cin.tie(0);
cout.tie(0);
int n; cin>>n;
vector<int> a(n + 1), h(n + 1), c(n + 1), g[n + 1];
for (int i = 1; i <= n; i++){
cin>>a[i]>>h[i]>>c[i];
if (a[i] != i) g[a[i]].pb(i);
}
DS T[n + 1];
vector<ll> dp(n + 1), dp1(n + 1), S(n + 1);
function<void(int, int)> solve = [&](int v, int pr){
S[v] = c[v];
T[v].init();
ll sum = 0;
for (int i: g[v]){
if (i == pr) continue;
solve(i, v);
S[v] += S[i];
sum += dp[i];
if (T[v].size() < T[i].size()) swap(T[v], T[i]);
T[i].fn();
for (auto [p, x, y]: T[i].vv){
T[v].add((int) p, {x, y});
}
}
auto [x, y] = T[v].get(1, h[v]);
dp1[v] = x + (S[v] - c[v] - y);
assert(x >= 0 && ((S[v] - c[v] - y) >= 0));
dp[v] = min(dp1[v], c[v] + sum);
pll g = T[v].get(1, h[v] + 1);
T[v].clear(h[v] + 1);
T[v].add(1, {dp1[v], S[v]});
T[v].add(h[v] + 1, {g.ff - dp1[v], g.ss - S[v]});
};
solve(1, 0);
cout<<dp[1]<<"\n";
}
/*
 * Pasted judge results (column headers translated from Korean):
 *
 *   #  | Result    | Run time | Memory   | Grader output
 *   ---+-----------+----------+----------+----------------------
 *   1  | Correct   | 1 ms     | 344 KB   | Output is correct
 *   2  | Correct   | 0 ms     | 348 KB   | Output is correct
 *   3  | Correct   | 0 ms     | 348 KB   | Output is correct
 *   4  | Correct   | 0 ms     | 348 KB   | Output is correct
 *   5  | Incorrect | 20 ms    | 21508 KB | Output isn't correct
 *   6  | Halted    | 0 ms     | 0 KB     | -
 */
/*
 * Pasted judge results (column headers translated from Korean):
 *
 *   #  | Result    | Run time | Memory   | Grader output
 *   ---+-----------+----------+----------+----------------------
 *   1  | Correct   | 1 ms     | 344 KB   | Output is correct
 *   2  | Correct   | 0 ms     | 348 KB   | Output is correct
 *   3  | Correct   | 0 ms     | 348 KB   | Output is correct
 *   4  | Correct   | 0 ms     | 348 KB   | Output is correct
 *   5  | Incorrect | 20 ms    | 21508 KB | Output isn't correct
 *   6  | Halted    | 0 ms     | 0 KB     | -
 */
/*
 * Pasted judge results (column headers translated from Korean):
 *
 *   #  | Result    | Run time | Memory   | Grader output
 *   ---+-----------+----------+----------+----------------------
 *   1  | Correct   | 1 ms     | 344 KB   | Output is correct
 *   2  | Correct   | 0 ms     | 348 KB   | Output is correct
 *   3  | Correct   | 0 ms     | 348 KB   | Output is correct
 *   4  | Correct   | 0 ms     | 348 KB   | Output is correct
 *   5  | Incorrect | 20 ms    | 21508 KB | Output isn't correct
 *   6  | Halted    | 0 ms     | 0 KB     | -
 */