제출 #1212842

# | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리
1212842 |  | vanea | 트리 (IOI24_tree) | C++20 | 0 / 100 | 2095 ms | 60628 KiB
#include <bits/stdc++.h>
#include "tree.h"
using namespace std;
using ll = long long;

// Capacity of every per-vertex / per-Euler-position array.
const int mxN = 5e5+10;
// "+infinity" weight used for empty segment-tree positions; assumes every w[i]
// is below this value.
const int INF = 1e9+10;

// Tree adjacency list (init() stores every edge in both directions).
vector<int> adj[mxN];
// w: vertex weights; ans: cost accumulated by the current query.
ll w[mxN], ans;
// Subtree coefficient sum chosen for each vertex by dfs() in the current query.
ll cnt[mxN];
// Euler-tour times from dfs1(): the subtree of v is positions [tin[v], tout[v]].
int timer = -1, tin[mxN], tout[mxN];
// Segment tree over Euler positions [0, timer+1]; entries are
// {weight, spare amount, position}. A recursive segment tree over m leaves uses
// node indices up to roughly 4m, so size it by 4 * mxN (mxN alone overflows once
// m reaches about 2e5).
array<ll, 3> st[4 * mxN];

// Segment-tree combine for entries {weight, spare amount, position}.
// Returns the entry with the smaller weight, and the left one (a) on ties.
// Entries are never merged: the amount in [1] always belongs to the single
// position in [2]. Callers write the entry back to that one position, so a summed
// amount from several positions would be double-counted.
std::array<long long, 3> comb(std::array<long long, 3> a, std::array<long long, 3> b) {
    return b[0] < a[0] ? b : a;
}

// Resets every position in [l, r] to the empty entry {INF, 0, position} and
// recomputes the internal nodes bottom-up with comb().
void build(int node, int l, int r) {
    if(l < r) {
        const int mid = (l + r) >> 1;
        const int left_child = node << 1;
        const int right_child = left_child | 1;
        build(left_child, l, mid);
        build(right_child, mid + 1, r);
        st[node] = comb(st[left_child], st[right_child]);
    }
    else {
        st[node] = {INF, 0, l};
    }
}

// Point assignment: stores `now` at position k and refreshes the aggregates on
// the path back to `node`. A k outside [l, r] leaves this subtree untouched.
void upd(int node, int l, int r, int k, array<ll, 3> now) {
    if(k < l || r < k) return;
    if(l == r) {
        st[node] = now;
        return;
    }
    const int mid = (l + r) / 2;
    // Only the child whose range contains k can change.
    if(k <= mid) upd(2 * node, l, mid, k, now);
    else upd(2 * node + 1, mid + 1, r, k, now);
    st[node] = comb(st[2 * node], st[2 * node + 1]);
}

// Range query: returns the cheapest entry (as chosen by comb) among positions
// [l1, r1], for the tree node `node` covering [l, r].
// Parts of [l, r] outside the query contribute the neutral entry {INF, 0, -1}.
// It must not be {0, 0, 0}: weight 0 would compare below every positive real
// weight and, carrying amount 0, would hide all real entries from the caller.
// Position -1 is never written back because its amount is 0.
array<ll, 3> qry(int node, int l, int r, int l1, int r1) {
    if(l1 <= l && r <= r1) return st[node];
    if(l1 > r || r1 < l) return {INF, 0, -1};
    int mid = (l+r)/2;
    return comb(qry(node*2, l, mid, l1, r1), qry(node*2+1, mid+1, r, l1, r1));
}

void dfs(int node, int p, ll l, ll r) {
    bool leaf = true;
    for(auto it : adj[node]) {
        if(it == p) continue;
        dfs(it, node, l, r);
        cnt[node] += cnt[it];
        leaf = false;
    }
    if(leaf) {
        ans += l * w[node];
        cnt[node] = l;
        return ;
    }
    array<ll, 3> mn = qry(1, 0, timer+1, tin[node], tout[node]);
    while(mn[1] != 0 && mn[0] < w[node] && cnt[node] > r) {
        ll now = cnt[node]-mn[1];
        if(now <= r) {
            mn[1] -= (cnt[node]-r);
            ans += (cnt[node]-r)*mn[0];
            if(mn[1] == 0) mn[0] = INF;
            upd(1, 0, timer+1, mn[2], mn);
            cnt[node] = r;
        }
        else {
            cnt[node] -= mn[1];
            ans += mn[1]*mn[0];
            upd(1, 0, timer+1, mn[2], {INF, 0, mn[2]});
        }
    }
    if(cnt[node] > r) {
        ans += (cnt[node]-r)*w[node];
        cnt[node] = r;
    }
    upd(1, 0, timer+1, tin[node], {w[node], r-l, tin[node]});
}

// Euler tour from `node` (p is the vertex we came from, -1 for the root).
// tin[v] is v's pre-order index and tout[v] is the largest index inside v's
// subtree, so v's subtree occupies exactly positions [tin[v], tout[v]].
void dfs1(int node, int p) {
    timer += 1;
    tin[node] = timer;
    for(size_t i = 0; i < adj[node].size(); ++i) {
        const int next = adj[node][i];
        if(next != p) dfs1(next, node);
    }
    tout[node] = timer;
}

// Stores the tree for later queries. It copies the weights W into w[], builds an
// undirected adjacency list from the parent array P (P[0] is not used; vertex 0
// is the root), and computes the Euler-tour times with dfs1.
void init(vector<int> P, vector<int> W) {
    const int n = static_cast<int>(P.size());
    for(int v = 0; v < n; ++v) {
        w[v] = W[v];
        if(v == 0) continue;
        const int parent = P[v];
        adj[parent].push_back(v);
        adj[v].push_back(parent);
    }
    dfs1(0, -1);
}

// Answers one query with bounds l (L) and r (R). It clears the cost
// accumulator, resets the segment tree, runs the greedy dfs from root 0, and
// returns the total cost. Every call rebuilds the segment tree and walks the
// whole tree.
ll query(int l, int r) {
    const ll lo = l;
    const ll hi = r;
    ans = 0;
    build(1, 0, timer + 1);
    dfs(0, -1, lo, hi);
    return ans;
}

/*int main()
{
    init({-1, 0, 0, 2, 2, 2}, {1, 10, 5, 4, 3, 7});
    for(int i = 0; i < 6; i++) {
        cout << tin[i] << ' ' << tout[i] << '\n';
    }
    cout << query(1, 3);
}*/
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...