Submission #1245627

# | Time | Username | Problem | Language | Result | Execution time | Memory
1245627 | Boomyday | 트리 (Tree, IOI24_tree) | C++20
41 / 100
128 ms | 52120 KiB
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
// Every `int` written after this line (including inside tree.h, which is
// included below) is textually replaced by `long long`. That is presumably why
// init/query spell their 32-bit parameters as `signed`.
#define int ll
// Sentinel weight: leaves (dfs1) and collapsed vertices (dfs) get weight INF so
// the min-weight segment tree never selects them.
// NOTE(review): assumes every real vertex weight is below 1e9 — confirm.
const int  INF = 1e9;
// Number of vertices (set in init).
int n;
#include "tree.h"
// Parent array and vertex weights, copied from init's arguments. Entries of w
// are later overwritten with INF for leaves and for processed vertices.
std::vector<int> p, w;
// adj[u] lists the children of u (built from p in init).
vector<vector<int>> adj;

// is_leaf[u]: u is a leaf of the current (partially collapsed) tree.
vector<bool> is_leaf;
// Euler-tour interval: the subtree of u occupies positions [LL[u], RR[u]].
vector<int> LL, RR;
// One entry per dfs round: (weight increment over the caller's `sub`,
// number of current leaves in the subtree).
vector<pair<ll,ll>> special_pairs; // (w, c)
// Next free Euler position; advanced by dfs1.
int T = 0;
ll leaf_sum = 0; // sum of weights of all leaves

// plan for the main dfs(u, sub), repeated until u itself becomes a leaf:
// - find the special node: the minimum-(weight, id) vertex in u's subtree,
//   and record (its weight - sub, current leaf count of u's subtree),
// - recurse into all children of the special node,
// - clean up the special node: it becomes a leaf with weight INF,
// - rerun dfs on u.


// Offset of the bottom level in both iterative segment trees (2^18).
// NOTE(review): Euler positions must stay below this, i.e. assumes n <= 262144.
int ssz = 262144;

// Sum tree over Euler positions: 1 where a vertex is currently a leaf, else 0.
vector<ll> seg_leaf(2 * 262144, 0);
// Min tree over Euler positions of (weight, vertex id); unused slots hold {INF, -1}.
vector<pair<ll, int>> seg_weight(2 * 262144, {INF, -1}); // value, index


// Prefix-summed coefficient tables filled by init and read by query:
// answer(L, R) = L * Lx[k] + R * Rx[k] with k = ceil(R / L).
vector<ll> Rx, Lx;

// Store `x` (the leaf indicator, 0 or 1) at Euler position `i` of the sum tree
// `seg_leaf`, then recompute every ancestor's sum bottom-up.
void add_leaf(int i, ll x) {
    int node = i + ssz;
    seg_leaf[node] = x;
    for (node >>= 1; node >= 1; node >>= 1) {
        seg_leaf[node] = seg_leaf[node << 1] + seg_leaf[(node << 1) | 1];
    }
}

int sum_leaf(int l, int r) {
    l += ssz; r += ssz; ll ans = 0;
    while (l <= r) {
        if (l % 2 == 1) ans += seg_leaf[l++];
        if (r % 2 == 0) ans += seg_leaf[r--];
        l /= 2; r /= 2;
    }
    return ans;
}

// Overwrite the (weight, vertex id) pair at Euler position `i` in the min tree
// `seg_weight` and refresh each ancestor with the smaller of its two children.
void upd_weight(int i, pair<ll, int> x) {
    int node = i + ssz;
    seg_weight[node] = x;
    while ((node >>= 1) > 0) {
        seg_weight[node] = std::min(seg_weight[2 * node], seg_weight[2 * node + 1]);
    }
}
// Lexicographically smallest (weight, vertex id) over the inclusive Euler range
// [l, r]; returns {INF, -1} when nothing in range is smaller than that.
pair<ll, int> min_val(int l, int r) {
    pair<ll, int> best{INF, -1};
    for (l += ssz, r += ssz; l <= r; l >>= 1, r >>= 1) {
        if (l & 1) best = std::min(best, seg_weight[l++]);
        if (!(r & 1)) best = std::min(best, seg_weight[r--]);
    }
    return best;
}

// Euler tour of the subtree rooted at u.
//  - LL[u] gets u's preorder position; RR[u] gets the last position used inside
//    u's subtree, so the subtree is exactly [LL[u], RR[u]].
//  - is_leaf[u] is set to whether u has no children.
//  - For a leaf, its weight is added to leaf_sum and w[u] is replaced by INF
//    so the min-weight tree never picks it.
// Recursion depth equals the depth of the tree below u.
void dfs1(int u){
    LL[u] = T++;
    const bool childless = adj[u].empty();
    is_leaf[u] = childless;
    for (const int child : adj[u]) {
        dfs1(child);
    }
    RR[u] = T - 1;
    if (childless) {
        leaf_sum += w[u];
        w[u] = INF;
    }
}

void dfs(int u,  ll sub){ // sub is the special subtrahend
    // if leaf
    if (is_leaf[u]){
        // just clear the node up
        is_leaf[u] = false;
        add_leaf(LL[u], 0);
        return;
    }
    // find the special node
    pair<ll, int> pr_weight = min_val(LL[u], RR[u]);
    ll min_weight = pr_weight.first;
    int index = pr_weight.second;
    special_pairs.push_back({min_weight-sub, sum_leaf(LL[u], RR[u])});
    sub = min_weight;
    // recurse all children
    for(int v:adj[index]){
        dfs(v, sub); // recurse with the special subtrahend
    }
    // clean up node
    is_leaf[index] = true;
    add_leaf(LL[index], 1);
    upd_weight(LL[index], {INF, index});
    w[index] = INF; // it's cleaned
    // rerun dfs
    dfs(u, sub);
}

// Build every table that query() reads.
//   P: parent of each vertex (P[0] is the root's, unused); W: vertex weights.
// Steps: build child lists, run the Euler tour (dfs1), load both segment trees,
// run the decomposition (dfs) to collect special_pairs, then turn those pairs
// into difference arrays and prefix-sum them into Lx / Rx (indices 0..n).
void init(std::vector<signed> P, std::vector<signed> W) {
    p.assign(P.begin(), P.end());
    w.assign(W.begin(), W.end());
    n = static_cast<int>(p.size());

    adj.resize(n);
    for (int child = 1; child < n; ++child) {
        adj[p[child]].push_back(child);
    }

    is_leaf.resize(n, false);
    LL.resize(n, 0);
    RR.resize(n, 0);
    dfs1(0);

    // Seed the segment trees: leaf indicators and (weight, id) per Euler slot.
    for (int v = 0; v < n; ++v) {
        if (is_leaf[v]) add_leaf(LL[v], 1);
        upd_weight(LL[v], {w[v], v});
    }
    dfs(0, 0);

    // Pair (wt, c) contributes c*wt*L - wt*R to every k in [1, c]; encode that
    // as range additions on difference arrays. leaf_sum sits at index 0 so that
    // after prefix summation it is part of every Lx[k].
    Lx.resize(n + 1, 0);
    Rx.resize(n + 1, 0);
    Lx[0] = leaf_sum;
    for (const auto& [delta, cnt] : special_pairs) {
        Lx[1] += cnt * delta;
        Lx[cnt + 1] -= cnt * delta;
        Rx[1] -= delta;
        Rx[cnt + 1] += delta;
    }
    std::partial_sum(Lx.begin(), Lx.end(), Lx.begin());
    std::partial_sum(Rx.begin(), Rx.end(), Rx.begin());
}

long long query(signed L, signed R) {
   ll k = (R+L-1)/L;
   return L*Lx[k] + R*Rx[k];
}
/*
int main() {
    int N;
    assert(1 == scanf("%d", &N));
    std::vector<int> P(N);
    P[0] = -1;
    for (int i = 1; i < N; i++)
        assert(1 == scanf("%d", &P[i]));
    std::vector<int> W(N);
    for (int i = 0; i < N; i++)
        assert(1 == scanf("%d", &W[i]));
    int Q;
    assert(1 == scanf("%d", &Q));
    std::vector<int> L(Q), R(Q);
    for (int j = 0; j < Q; j++)
        assert(2 == scanf("%d%d", &L[j], &R[j]));
    fclose(stdin);

    init(P, W);
    std::vector<long long> A(Q);
    for (int j = 0; j < Q; j++)
        A[j] = query(L[j], R[j]);

    for (int j = 0; j < Q; j++)
        printf("%lld\n", A[j]);
    fclose(stdout);

    return 0;
}
*/
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...