Submission #1360250

#TimeUsernameProblemLanguageResultExecution timeMemory
1360250ws555Tree (IOI24_tree)C++20
100 / 100
87 ms26396 KiB
#include "tree.h"
#include <bits/stdc++.h>
using namespace std;

// State shared between init() and query(); init() overwrites all of it.
//   leafCost     - sum of W[v] over vertices v that have no children.
//   baseDropCost - sum of (children(v) - 1) * W[v] over vertices with children.
//   M            - total token count: sum of (children(v) - 1) over vertices
//                  with children.
//   gain[j]      - for 1 <= j <= M, filled by the weight sweep in init();
//                  gain[0] and gain[M + 1] stay 0.
//   prefGain[j]  - gain[1] + ... + gain[j]; prefGain[0] == 0.
static long long leafCost;
static long long baseDropCost;
static int M;

static std::vector<long long> gain;
static std::vector<long long> prefGain;

/**
 * Precomputes the tables read by query() for the tree described by the parent
 * array P (P[0] belongs to the root and is ignored) and vertex weights W.
 *
 * Each vertex with children owns (children - 1) tokens; leaves own none.
 * Vertices are activated in decreasing order of W and merged with their
 * already-active neighbours in a disjoint-set forest. Every component tracks
 * its token total and the weight at which that total last changed. When a
 * component is merged away, or when the sweep ends at weight 0, the weight
 * span it lived for is added to gain[1..tokens] through a difference array.
 *
 * Time is O(N log N), dominated by the sort.
 */
void init(std::vector<int> P, std::vector<int> W) {
    const int N = static_cast<int>(P.size());

    std::vector<std::vector<int>> adj(N);  // undirected tree edges
    std::vector<int> deg(N, 0);            // number of children of each vertex

    for (int i = 1; i < N; i++) {
        deg[P[i]]++;
        adj[i].push_back(P[i]);
        adj[P[i]].push_back(i);
    }

    leafCost = 0;
    baseDropCost = 0;

    std::vector<int> tokenCount(N, 0);
    for (int v = 0; v < N; v++) {
        if (deg[v] == 0) {
            leafCost += W[v];
        } else {
            tokenCount[v] = deg[v] - 1;
            baseDropCost += 1LL * tokenCount[v] * W[v];
        }
    }

    M = 0;
    for (int x : tokenCount) M += x;

    gain.assign(M + 2, 0);
    prefGain.assign(M + 2, 0);

    // No vertex has two or more children, so there is nothing to sweep.
    if (M == 0) return;

    std::vector<int> order(N);
    std::iota(order.begin(), order.end(), 0);
    std::sort(order.begin(), order.end(), [&](int a, int b) { return W[a] > W[b]; });

    std::vector<int> parent(N), compSize(N, 1), compTokens(N, 0);
    std::vector<long long> startWeight(N, 0);
    std::vector<char> active(N, false);
    std::iota(parent.begin(), parent.end(), 0);

    // Difference array over token thresholds 1..M (index M + 1 absorbs ends).
    std::vector<long long> diff(M + 3, 0);

    // Iterative find with full path compression. It stays iterative so that a
    // long parent chain can never exhaust the call stack.
    auto findRoot = [&](int x) {
        int root = x;
        while (parent[root] != root) root = parent[root];
        while (parent[x] != root) {
            const int next = parent[x];
            parent[x] = root;
            x = next;
        }
        return root;
    };

    // Credits the weight span [currentWeight, startWeight[r]] of component r
    // to every threshold 1..compTokens[r].
    auto closeComponent = [&](int r, long long currentWeight) {
        const long long len = startWeight[r] - currentWeight;
        const int c = compTokens[r];

        if (c > 0 && len > 0) {
            diff[1] += len;
            diff[c + 1] -= len;
        }
    };

    // Merges the components of a and b at weight currentWeight. Both are
    // closed first, so the choice of surviving root does not affect `diff`;
    // union by size keeps the forest shallow.
    auto unite = [&](int a, int b, long long currentWeight) {
        int ra = findRoot(a);
        int rb = findRoot(b);

        if (ra == rb) return;

        closeComponent(ra, currentWeight);
        closeComponent(rb, currentWeight);

        if (compSize[ra] < compSize[rb]) std::swap(ra, rb);

        parent[rb] = ra;
        compSize[ra] += compSize[rb];
        compTokens[ra] += compTokens[rb];
        startWeight[ra] = currentWeight;
    };

    for (int v : order) {
        active[v] = true;
        parent[v] = v;
        compTokens[v] = tokenCount[v];
        startWeight[v] = W[v];

        for (int u : adj[v]) {
            if (active[u]) {
                unite(v, u, W[v]);
            }
        }
    }

    // Every vertex is active now; close each surviving component at weight 0.
    for (int v = 0; v < N; v++) {
        if (parent[v] == v) {
            closeComponent(v, 0);
        }
    }

    long long cur = 0;
    for (int j = 1; j <= M; j++) {
        cur += diff[j];
        gain[j] = cur;
        prefGain[j] = prefGain[j - 1] + gain[j];
    }
}

/**
 * Evaluates the tables built by init() for the range [L, R].
 *
 * With slack = R - L, full = slack / L and partial = slack % L, the result is
 *   L * (leafCost + baseDropCost - prefGain[full]) - partial * gain[full + 1],
 * or simply L * leafCost when there are no tokens or full >= M.
 *
 * Expects 1 <= L <= R, as in the IOI 2024 "Tree" constraints; L == 0 would
 * divide by zero.
 */
long long query(int L, int R) {
    // Cost of giving every leaf exactly L with no internal-vertex penalty.
    const long long leavesOnly = 1LL * L * leafCost;
    if (M == 0) return leavesOnly;

    const long long slack = R - L;
    const long long full = slack / L;
    if (full >= M) return leavesOnly;

    const long long partial = slack % L;
    const long long base = leafCost + baseDropCost;

    return 1LL * L * (base - prefGain[full]) - partial * gain[full + 1];
}
#Result Execution timeMemoryGrader output
Fetching results...
#Result Execution timeMemoryGrader output
Fetching results...
#Result Execution timeMemoryGrader output
Fetching results...
#Result Execution timeMemoryGrader output
Fetching results...
#Result Execution timeMemoryGrader output
Fetching results...
#Result Execution timeMemoryGrader output
Fetching results...
#Result Execution timeMemoryGrader output
Fetching results...
#Result Execution timeMemoryGrader output
Fetching results...