// 제출 #1354275
//
// # | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리
// 1354275 | tickcrossy | 트리 (IOI24_tree) | C++20
// 0 / 100
// 2096 ms | 93296 KiB
#include <bits/stdc++.h>
using namespace std;

const int MAXN = 200005;

// Capacity of the segment-tree node pool shared by every tree built in one query()
// (query() resets node_cnt, so nodes are reused across queries; index 0 is treated
// as the empty tree and is never written). One query can allocate several fresh
// root-to-leaf paths (about 20 nodes each when M is close to 2 * MAXN) per vertex,
// so the former 6'000'000 could be exceeded for N = 200'000.
const int MAX_NODES = 20000000;

// Node of a dynamic segment tree over the compressed coordinates X[1..M]. A tree
// stores the slope changes of a piecewise-linear function: `sum` is the total slope
// change in the subtree and `sum_x` the total of (slope change * X[index]).
// lc / rc are indices into tr, 0 meaning "no child".
struct Node {
    int lc, rc;
    double sum, sum_x;
} tr[MAX_NODES];

int node_cnt = 0;         // last index handed out from tr
int root[MAXN];           // root[u]: tree of slope changes for vertex u (per query)
double init_slope[MAXN];  // slope to the left of every change stored in root[u]
double f0_val[MAXN];      // function value at coordinate 0; the best value once u is solved

int N, M;                 // number of vertices; number of distinct coordinates in X
vector<int> children[MAXN];
long long W[MAXN];
long long X[MAXN * 2];    // X[1..M]: sorted distinct values of {0} and {-W[i], W[i]}

// Adds a slope change `val` at compressed index idx (coordinate X[idx]) to the
// dynamic segment tree rooted at p over [l, r]. Missing nodes on the path are
// allocated from tr and zero-initialised. Every node on the root-to-leaf path
// accumulates val into `sum` and val * X[idx] into `sum_x`.
void update(int& p, int l, int r, int idx, double val) {
    int* slot = &p;  // the child pointer (or root) we are about to descend into
    for (;;) {
        if (*slot == 0) {
            *slot = ++node_cnt;
            tr[*slot] = {0, 0, 0, 0};
        }
        auto& node = tr[*slot];
        node.sum += val;
        node.sum_x += val * X[idx];
        if (l == r) break;
        const int mid = l + (r - l) / 2;
        if (idx > mid) {
            slot = &node.rc;
            l = mid + 1;
        } else {
            slot = &node.lc;
            r = mid;
        }
    }
}

// Deletes every slope change with index in [ql, qr] from the tree rooted at p over
// [l, r]. Fully covered subtrees are unlinked by zeroing the pointer that refers to
// them (their nodes are not reclaimed). Partially covered nodes get sum / sum_x
// recomputed from whatever children remain.
void clear_range(int& p, int l, int r, int ql, int qr) {
    if (p == 0) return;
    const bool overlaps = ql <= r && l <= qr;
    if (!overlaps) return;
    const bool covered = ql <= l && r <= qr;
    if (covered) {
        p = 0;
        return;
    }
    const int mid = l + (r - l) / 2;
    auto& node = tr[p];
    clear_range(node.lc, l, mid, ql, qr);
    clear_range(node.rc, mid + 1, r, ql, qr);
    node.sum = (node.lc ? tr[node.lc].sum : 0) + (node.rc ? tr[node.rc].sum : 0);
    node.sum_x = (node.lc ? tr[node.lc].sum_x : 0) + (node.rc ? tr[node.rc].sum_x : 0);
}

// Shared range walk: adds up `sum_x` (weighted == true) or `sum` (weighted == false)
// over the nodes of the tree rooted at p (range [l, r]) that exactly cover [ql, qr].
// Empty subtrees and an empty query range contribute 0.
static double range_total(int p, int l, int r, int ql, int qr, bool weighted) {
    if (p == 0 || ql > r || qr < l) return 0.0;
    if (ql <= l && r <= qr) return weighted ? tr[p].sum_x : tr[p].sum;
    const int mid = l + (r - l) / 2;
    const double left_part = range_total(tr[p].lc, l, mid, ql, qr, weighted);
    return left_part + range_total(tr[p].rc, mid + 1, r, ql, qr, weighted);
}

// Total slope change stored at compressed indices [ql, qr].
double query_sum(int p, int l, int r, int ql, int qr) {
    return range_total(p, l, r, ql, qr, false);
}

// Total of (slope change * X[index]) over compressed indices [ql, qr].
double query_sum_x(int p, int l, int r, int ql, int qr) {
    return range_total(p, l, r, ql, qr, true);
}

// Descends the tree rooted at p over [l, r] to find the first index i at which the
// running slope (current_sum plus every slope change at indices <= i) drops below
// threshold, using a 1e-9 tolerance. Returns i, or M + 1 if there is none.
// On return, current_sum includes all changes skipped on the way. It also includes
// the final leaf's change when the walk ends at a stored leaf; nothing extra is
// added when it ends in an empty subtree. Each step decides using the running
// slope at the end of the left half, so the result is only exact when the running
// slope is non-increasing.
int find_first(int p, int l, int r, double& current_sum, double threshold) {
    const double limit = threshold - 1e-9;
    while (p != 0 && l < r) {
        const int mid = l + (r - l) / 2;
        const int left = tr[p].lc;
        const double left_total = left ? tr[left].sum : 0;
        if (current_sum + left_total < limit) {
            p = left;
            r = mid;
        } else {
            current_sum += left_total;
            p = tr[p].rc;
            l = mid + 1;
        }
    }
    if (p != 0) current_sum += tr[p].sum;  // stopped on a stored leaf
    return current_sum < limit ? l : M + 1;
}

int merge_trees(int p, int q, int l, int r) {
    if (!p || !q) return p ? p : q;
    int rt = ++node_cnt;
    tr[rt] = {0, 0, tr[p].sum + tr[q].sum, tr[p].sum_x + tr[q].sum_x};
    if (l == r) return rt;
    int mid = l + (r - l) / 2;
    tr[rt].lc = merge_trees(tr[p].lc, tr[q].lc, l, mid);
    tr[rt].rc = merge_trees(tr[p].rc, tr[q].rc, mid + 1, r);
    return rt;
}

// Maps a coordinate to its 1-based position in the sorted array X[1..M]: the first
// index whose value is >= val, or M + 1 when every value is smaller.
int get_idx(long long val) {
    const long long* const first = X + 1;
    const long long* const last = first + M;
    const long long* const pos =
        partition_point(first, last, [val](long long x) { return x < val; });
    return static_cast<int>(pos - X);
}

// Grader entry point. Stores the weights and builds the child lists from the parent
// array (P[0] is not read). It also builds the compressed coordinate array
// X[1..M]: the sorted, deduplicated values of {0} and {-W[i], W[i]}.
void init(std::vector<int> P, std::vector<int> w_in) {
    N = static_cast<int>(w_in.size());
    vector<long long> coords{0};
    for (int i = 0; i < N; ++i) {
        W[i] = w_in[i];
        coords.push_back(-W[i]);
        coords.push_back(W[i]);
        if (i > 0) children[P[i]].push_back(i);
    }
    sort(coords.begin(), coords.end());
    coords.erase(unique(coords.begin(), coords.end()), coords.end());
    M = static_cast<int>(coords.size());
    int pos = 0;
    for (long long x : coords) X[++pos] = x;
}

// Evaluates vertex u's function at coordinate Y. The function is represented by its
// value f0_val[u] at 0, the base slope init_slope[u], and the slope changes in
// root[u]. It is integrated from 0 to Y. For Y < 0 an extra -c * |Y| term is
// applied. idx_0 must be get_idx(0). Y is expected to be one of the coordinates
// in X, since get_idx rounds anything else up.
double eval_E(int u, long long Y, double c, int idx_0) {
    const double base = f0_val[u];
    if (Y == 0) return base;
    const int idx_y = get_idx(Y);
    if (Y < 0) {
        const long long dist = abs(Y);
        // slope just to the right of Y
        const double slope_at_y = init_slope[u] + query_sum(root[u], 1, M, 1, idx_y);
        const double area = dist * slope_at_y - query_sum_x(root[u], 1, M, idx_y + 1, idx_0);
        return base - dist * c - area;
    }
    // slope just to the right of 0
    const double slope_at_0 = init_slope[u] + query_sum(root[u], 1, M, 1, idx_0);
    return base + Y * slope_at_0 + Y * query_sum(root[u], 1, M, idx_0 + 1, idx_y)
           - query_sum_x(root[u], 1, M, idx_0 + 1, idx_y);
}

// Turns the stored slopes of E_v (base init_slope[v], changes in root[v]) into those
// of F_v(Y) = D_v(Y) - Y, where D_v(t) is the best value of v's subtree when v's
// parent takes coordinate t. Mathematically F_v' = clamp(E_v', -c, 0) on
// [-W[v], W[v]]. Left of -W[v] the slope is +inf, clamped to 0; right of W[v] it is
// -inf, clamped to -c. Leaves init_slope[v] == 0.
static void to_child_contribution(int v, double c) {
    const int lo = get_idx(-W[v]);
    const int hi = get_idx(W[v]);

    // Restrict to [-W[v], W[v]]: changes left of -W[v] fold into the base slope and
    // changes right of W[v] are dropped.
    init_slope[v] += query_sum(root[v], 1, M, 1, lo - 1);
    clear_range(root[v], 1, M, 1, lo - 1);
    clear_range(root[v], 1, M, hi + 1, M);

    // The slope left of -W[v] must read 0 (clamped +inf), so the base slope has to
    // become a breakpoint at -W[v] *before* clamping. Otherwise the clamp places it
    // at X[1].
    if (init_slope[v] != 0.0) update(root[v], 1, M, lo, init_slope[v]);
    init_slope[v] = 0.0;

    // Clamp from above at 0: every slope before the first negative one becomes 0.
    double prefix = 0.0;
    const int i_0 = find_first(root[v], 1, M, prefix, 0.0);
    if (i_0 <= M) {
        const double before = prefix - query_sum(root[v], 1, M, i_0, i_0);
        clear_range(root[v], 1, M, 1, i_0 - 1);
        update(root[v], 1, M, i_0, before);  // slope right of X[i_0] is still `prefix`
    } else {
        root[v] = 0;  // no slope is negative: all of them clamp to 0
    }

    // Clamp from below at -c: the first slope under -c becomes exactly -c, and every
    // later change is dropped. `prefix` already includes the change at i_c.
    prefix = 0.0;
    const int i_c = find_first(root[v], 1, M, prefix, -c);
    if (i_c <= M) {
        update(root[v], 1, M, i_c, -c - prefix);
        clear_range(root[v], 1, M, i_c + 1, M);
    }

    // Right of W[v] the slope is -inf, clamped to -c.
    const double end_slope = root[v] ? tr[root[v]].sum : 0.0;
    if (-c - end_slope < -1e-9) update(root[v], 1, M, hi, -c - end_slope);
}

// Maximises E_u over Y in [-W[u], W[u]]. E_u is concave, so the optimum is at 0, at
// +-W[u], or at the breakpoint where its slope turns negative. For Y < 0 that slope
// is (stored slope + c) because of the -c*|Y| term; for Y > 0 it is the stored slope.
static double best_value(int u, double c, int idx_0) {
    double best = eval_E(u, 0, c, idx_0);
    best = max(best, eval_E(u, -W[u], c, idx_0));
    best = max(best, eval_E(u, W[u], c, idx_0));

    double prefix = init_slope[u];
    const int i_left = find_first(root[u], 1, M, prefix, -c);
    if (i_left <= M && X[i_left] < 0 && X[i_left] >= -W[u])
        best = max(best, eval_E(u, X[i_left], c, idx_0));

    prefix = init_slope[u];
    const int i_right = find_first(root[u], 1, M, prefix, 0.0);
    if (i_right <= M && X[i_right] > 0 && X[i_right] <= W[u])
        best = max(best, eval_E(u, X[i_right], c, idx_0));
    return best;
}

// Grader entry point: minimum total cost for the query (L, R).
//
// All values are divided by R. By LP duality, answer / R is
//   max over z_u in [-W[u], W[u]] of  sum_u phi(z_parent(u) - z_u),  z_parent(root) = 0,
// where phi(a) = a for a < 0 and (1 - c) * a for a >= 0, with c = 1 - L / R.
// For vertex u, E_u(Y) = phi(-Y) + sum over children v of D_v(Y), and D_u(0) = max E_u.
// Vertices are solved from N - 1 down to 0, which relies on P[i] < i.
//
// NOTE(review): every query rebuilds all N vertices (O(N log M) per query), which is
// only fast enough for few queries. All arithmetic is in double, so very large
// answers (beyond ~2^53) cannot be represented exactly. Exact results would need
// 64-bit integer node fields with slopes scaled by R.
long long query(int L, int R) {
    const double c = 1.0 - static_cast<double>(L) / R;
    node_cnt = 0;  // the node pool is reused from scratch by every query
    const int idx_0 = get_idx(0);

    for (int u = N - 1; u >= 0; u--) {
        root[u] = 0;
        // phi(-Y) contributes slope -1 and every child's D_v = Y + F_v contributes +1.
        // size() is unsigned: convert before subtracting, otherwise a leaf would get
        // 2^64 - 1 instead of -1.
        init_slope[u] = static_cast<double>(children[u].size()) - 1.0;
        f0_val[u] = 0.0;

        for (int v : children[u]) {
            // Leaves init_slope[v] == 0, so children add nothing to the base slope.
            to_child_contribution(v, c);
            root[u] = merge_trees(root[u], root[v], 1, M);
            f0_val[u] += f0_val[v];  // F_v(0) = D_v(0)
        }

        f0_val[u] = best_value(u, c, idx_0);  // now D_u(0)
    }

    return llround(f0_val[0] * R);
}
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...