Submission #1354282

# | Time | Username | Problem | Language | Result | Execution time | Memory
1354282 | tickcrossy | Tree (IOI24_tree) | C++20
0 / 100
2092 ms | 204552 KiB
#include <iostream>
#include <vector>
#include <cmath>
#include <algorithm>

using namespace std;

// Capacity of the per-vertex arrays below (root, val_node, d_deg, children, W).
const int MAXN = 200005;
// Capacity of the segment-tree node pool `tr`.
// NOTE(review): no allocation site (`++node_cnt`) checks this bound, so running past it
// silently writes beyond the end of `tr`.
const int MAX_NODES = 8000000;

// One node of a dynamically allocated segment tree over coordinate indices 1..M.
// Pool index 0 is the null node: every tree routine treats `p == 0` as an empty subtree.
struct Node {
    int lc, rc;          // pool indices of the left / right child (0 = none)
    double sum, sum_x;   // mass stored in this range, and that mass weighted by X[index]
} tr[MAX_NODES];

int node_cnt = 0;        // last pool index handed out; query() resets it to 0 on every call
int root[MAXN];          // root[v]: pool index of the root of vertex v's tree
double val_node[MAXN];   // val_node[v]: value query() computed for vertex v
int d_deg[MAXN];         // d_deg[v]: number of children of vertex v

int N, M;                        // N: number of vertices; M: number of coordinates in X[1..M]
vector<int> children[MAXN];      // children[v]: every vertex i >= 1 with P[i] == v
long long W[MAXN];               // W[i]: weight of vertex i, copied from init()
long long X[MAXN * 2];           // X[1..M]: sorted, deduplicated values of 0 and +/-W[i] (1-based)

// Maps a coordinate value to its 1-based index in X[1..M]: the position of the first entry
// that is >= val, or M + 1 when every entry is smaller.
int get_idx(long long val) {
    const long long* const first = X + 1;
    const long long* const last = X + M + 1;
    const long long* const hit = std::lower_bound(first, last, val);
    return static_cast<int>(hit - X);
}

int merge_tree(int p, int q, int l, int r) {
    if (!p || !q) return p ? p : q;
    int rt = ++node_cnt;
    tr[rt].sum = tr[p].sum + tr[q].sum;
    tr[rt].sum_x = tr[p].sum_x + tr[q].sum_x;
    if (l == r) return rt;
    int mid = l + (r - l) / 2;
    tr[rt].lc = merge_tree(tr[p].lc, tr[q].lc, l, mid);
    tr[rt].rc = merge_tree(tr[p].rc, tr[q].rc, mid + 1, r);
    return rt;
}

void clear_range(int& p, int l, int r, int ql, int qr) {
    if (!p || ql > r || qr < l) return;
    if (ql <= l && r <= qr) { p = 0; return; }
    int rt = ++node_cnt;
    tr[rt] = tr[p];
    p = rt;
    int mid = l + (r - l) / 2;
    clear_range(tr[p].lc, l, mid, ql, qr);
    clear_range(tr[p].rc, mid + 1, r, ql, qr);
    tr[p].sum = (tr[p].lc ? tr[tr[p].lc].sum : 0) + (tr[p].rc ? tr[tr[p].rc].sum : 0);
    tr[p].sum_x = (tr[p].lc ? tr[tr[p].lc].sum_x : 0) + (tr[p].rc ? tr[tr[p].rc].sum_x : 0);
}

void add_mass(int& p, int l, int r, int idx, double val) {
    if (!p) { p = ++node_cnt; tr[p] = {0, 0, 0, 0}; }
    else { int rt = ++node_cnt; tr[rt] = tr[p]; p = rt; }
    tr[p].sum += val;
    tr[p].sum_x += val * X[idx];
    if (l == r) return;
    int mid = l + (r - l) / 2;
    if (idx <= mid) add_mass(tr[p].lc, l, mid, idx, val);
    else add_mass(tr[p].rc, mid + 1, r, idx, val);
}

void set_mass(int& p, int l, int r, int idx, double val) {
    if (!p) { p = ++node_cnt; tr[p] = {0, 0, 0, 0}; }
    else { int rt = ++node_cnt; tr[rt] = tr[p]; p = rt; }
    if (l == r) {
        tr[p].sum = val;
        tr[p].sum_x = val * X[idx];
        return;
    }
    int mid = l + (r - l) / 2;
    if (idx <= mid) set_mass(tr[p].lc, l, mid, idx, val);
    else set_mass(tr[p].rc, mid + 1, r, idx, val);
    tr[p].sum = (tr[p].lc ? tr[tr[p].lc].sum : 0) + (tr[p].rc ? tr[tr[p].rc].sum : 0);
    tr[p].sum_x = (tr[p].lc ? tr[tr[p].lc].sum_x : 0) + (tr[p].rc ? tr[tr[p].rc].sum_x : 0);
}

// Returns the total mass stored at coordinate indices [ql, qr] in the subtree p covering [l, r].
double query_sum(int p, int l, int r, int ql, int qr) {
    const bool disjoint = qr < l || r < ql;
    if (p == 0 || disjoint) return 0.0;
    const bool covered = l >= ql && r <= qr;
    if (covered) return tr[p].sum;
    const int mid = l + (r - l) / 2;
    const double left_part = query_sum(tr[p].lc, l, mid, ql, qr);
    const double right_part = query_sum(tr[p].rc, mid + 1, r, ql, qr);
    return left_part + right_part;
}

// Returns the X-weighted mass (sum of mass * X[index]) stored at coordinate indices [ql, qr]
// in the subtree p covering [l, r].
double query_sum_x(int p, int l, int r, int ql, int qr) {
    if (p == 0) return 0.0;
    if (r < ql || qr < l) return 0.0;
    if (l >= ql && r <= qr) return tr[p].sum_x;
    const int mid = l + (r - l) / 2;
    const double lower = query_sum_x(tr[p].lc, l, mid, ql, qr);
    const double upper = query_sum_x(tr[p].rc, mid + 1, r, ql, qr);
    return lower + upper;
}

// Descends the tree at p (covering [l, r]) to the first coordinate index whose running prefix
// mass exceeds target + 1e-9. The running prefix mass is current_sum plus all mass to the left.
// current_sum is advanced by every left subtree skipped over, and by the leaf's own mass when a
// leaf is reached. Returns M + 1 when no such index is found, including when the descent
// reaches an empty node.
int find_first(int p, int l, int r, double& current_sum, double target) {
    const double bound = target + 1e-9;
    while (p != 0) {
        if (l == r) {
            current_sum += tr[p].sum;
            return current_sum > bound ? l : M + 1;
        }
        const int mid = l + (r - l) / 2;
        const int left_child = tr[p].lc;
        const double left_mass = left_child ? tr[left_child].sum : 0;
        if (current_sum + left_mass > bound) {
            p = left_child;
            r = mid;
        } else {
            current_sum += left_mass;
            p = tr[p].rc;
            l = mid + 1;
        }
    }
    return M + 1;
}

// Stores the tree and builds the sorted coordinate table X[1..M].
//   P    : P[i] is the parent of vertex i for i >= 1 (P[0] is not read).
//   w_in : w_in[i] is the weight of vertex i; its length defines N.
// query() walks vertices from N-1 down to 0 and reads each child's finished tree, so it
// relies on P[i] < i for every i >= 1. NOTE(review): this is not checked here.
void init(std::vector<int> P, std::vector<int> w_in) {
    N = static_cast<int>(w_in.size());
    for (int i = 0; i < N; i++) {
        W[i] = w_in[i];
        children[i].clear();  // drop adjacency left over from any earlier init() call
    }
    for (int i = 1; i < N; i++) children[P[i]].push_back(i);

    // Candidate coordinates: 0 plus -W[i] and +W[i] for every vertex, sorted and deduplicated.
    std::vector<long long> x_vals;
    x_vals.reserve(2 * w_in.size() + 1);
    x_vals.push_back(0);
    for (int i = 0; i < N; i++) {
        x_vals.push_back(-W[i]);
        x_vals.push_back(W[i]);
    }
    std::sort(x_vals.begin(), x_vals.end());
    x_vals.erase(std::unique(x_vals.begin(), x_vals.end()), x_vals.end());
    M = static_cast<int>(x_vals.size());
    std::copy(x_vals.begin(), x_vals.end(), X + 1);  // X is 1-based: X[1..M]
}

// Answers one query with bounds [L, R]. Every vertex's segment tree is rebuilt bottom-up, and
// the value computed for vertex 0 is returned, rounded to the nearest integer.
//
// For each vertex v, in decreasing index order (children must have larger indices than
// their parent):
//   1. Merge the children's trees into root[v] and sum their val_node values.
//   2. Evaluate the candidates -W[v], +W[v] and 0. Also evaluate the breakpoints p_L / p_R
//      (the first indices whose prefix mass exceeds target_L / target_R) when they lie
//      within [-W[v], W[v]]. Store the maximum in val_node[v].
//   3. Keep only the mass in [p_L, p_R] and re-weight those two end points. If p_L is past
//      the end, drop everything. Then top the total mass up to c = R - L at index M.
//   4. Move mass lying outside [-W[v], W[v]] onto the nearer end point.
//
// NOTE(review): the node pool is reset and all N vertices are reprocessed on every call.
// The cost of a query therefore grows with N regardless of L and R.
// NOTE(review): masses and values are doubles. If the true answer can exceed 2^53, the
// rounded result may be inexact — confirm against the problem's limits.
long long query(int L, int R) {
    const double c = R - L;
    node_cnt = 0;  // every tree is rebuilt from an empty pool
    const int idx_0 = get_idx(0);

    for (int v = N - 1; v >= 0; v--) {
        // Coordinate indices of -W[v] and +W[v]; used by the evaluation and by the final clamp.
        const int idx_minus = get_idx(-W[v]);
        const int idx_plus = get_idx(W[v]);

        // Step 1: absorb the children. Their roots are not read again afterwards.
        root[v] = 0;
        d_deg[v] = static_cast<int>(children[v].size());
        double val_sum = 0;
        for (int u : children[v]) {
            root[v] = merge_tree(root[v], root[u], 1, M);
            val_sum += val_node[u];
        }

        const double target_L = 1.0 * R * (d_deg[v] - 1);
        const double target_R = 1.0 * R * d_deg[v] - L;

        // Value of v's objective at coordinate index p_idx (-1e18 for an invalid index).
        auto eval = [&](int p_idx) -> double {
            if (p_idx < 1 || p_idx > M) return -1e18;
            const double p = X[p_idx];
            const double M_p = query_sum(root[v], 1, M, 1, p_idx);
            double integral_M = 0;
            if (p > 0) integral_M = p * M_p - query_sum_x(root[v], 1, M, idx_0 + 1, p_idx);
            else if (p < 0) integral_M = p * M_p - query_sum_x(root[v], 1, M, p_idx + 1, idx_0);
            double ans = val_sum + p * target_R - integral_M;
            if (p > 0) ans -= c * p;
            return ans;
        };
        // True when idx is a valid coordinate index whose value lies in [-W[v], W[v]].
        auto within_weight = [&](int idx) {
            return idx >= 1 && idx <= M && X[idx] >= -W[v] && X[idx] <= W[v];
        };

        double cur_sum = 0;
        const int p_L = find_first(root[v], 1, M, cur_sum, target_L + 1e-9);
        cur_sum = 0;
        const int p_R = find_first(root[v], 1, M, cur_sum, target_R - 1e-9);

        // Step 2: best value over the candidate points.
        double ans_v = eval(idx_minus);
        ans_v = max(ans_v, eval(idx_plus));
        ans_v = max(ans_v, eval(idx_0));
        if (within_weight(p_L)) ans_v = max(ans_v, eval(p_L));
        if (within_weight(p_R)) ans_v = max(ans_v, eval(p_R));
        val_node[v] = ans_v;

        // Step 3: keep the mass between the breakpoints, then top up to total mass c.
        if (p_L <= M) {
            const double m_L_old = query_sum(root[v], 1, M, 1, p_L);
            const double m_R_prev = query_sum(root[v], 1, M, 1, p_R - 1);
            clear_range(root[v], 1, M, 1, p_L - 1);
            clear_range(root[v], 1, M, p_R + 1, M);
            if (p_L == p_R) {
                set_mass(root[v], 1, M, p_L, c);
            } else {
                set_mass(root[v], 1, M, p_L, m_L_old - target_L);
                if (p_R <= M) set_mass(root[v], 1, M, p_R, target_R - m_R_prev);
            }
        } else {
            clear_range(root[v], 1, M, 1, M);
        }

        const double total_mass = query_sum(root[v], 1, M, 1, M);
        if (total_mass < c - 1e-9) add_mass(root[v], 1, M, M, c - total_mass);

        // Step 4: clamp mass outside [-W[v], W[v]] onto the corresponding end point.
        const double left_del = query_sum(root[v], 1, M, 1, idx_minus - 1);
        clear_range(root[v], 1, M, 1, idx_minus - 1);
        if (left_del > 1e-9) add_mass(root[v], 1, M, idx_minus, left_del);

        const double right_del = query_sum(root[v], 1, M, idx_plus + 1, M);
        clear_range(root[v], 1, M, idx_plus + 1, M);
        if (right_del > 1e-9) add_mass(root[v], 1, M, idx_plus, right_del);
    }
    return llround(val_node[0]);
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...