Submission #1093181

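| # | Time | Username | Problem | Language | Result | Execution time | Memory |
|---|------|----------|---------|----------|--------|----------------|--------|
| 1093181 | | shivansh0809 | Holiday (IOI14_holiday) | C++17 | 100 / 100 | 282 ms | 48720 KiB |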
#include "holiday.h"
#include <bits/stdc++.h>
using namespace std;

// Problem-size constants.
// N: 100'000 plus a little slack (presumably the task's maximum city count;
//    confirm against the official limits).
// LN: 20, which exceeds ceil(log2(N)) + 1 = 18.
constexpr int N = 100005;
constexpr int LN = 20;

// Persistent segment tree answering "sum of the k largest values of v in
// [l, r)".
//
// Positions are ranked by (value, index) in descending order. Leaf `pos`
// therefore stores the value of rank `pos`, and smaller leaf positions hold
// larger values. Version roots[i] contains v[0..i-1]. Subtracting version l
// from version r isolates v[l..r-1], and a descent that prefers the left
// child collects the k largest.
//
// Node storage uses growable std::vector members. The object itself stays
// small, so it is safe as a local variable, and n is not capped by a
// compile-time constant. Node 0 is an unused placeholder; leaves store -1 as
// their child indices.
struct PersistentSegTree
{
    vector<pair<int, int>> compress; // compress[rank] = {value, original index}
    vector<int> L, R;                // child node indices (-1 at leaves)
    vector<int> S;                   // number of elements in the node's subtree
    vector<long long> KS;            // sum of those elements
    int NEXT_FREE_INDEX = 1;         // next node index == nodes allocated so far
    vector<int> roots;               // roots[i]: version holding v[0..i-1]
    int MX;                          // leaf positions span [0, MX]; MX = n

    PersistentSegTree(int v[], int n) : MX(n)
    {
        // Rank every position by descending (value, index).
        compress.resize(n);
        for (int i = 0; i < n; i++)
            compress[i] = {v[i], i};
        sort(compress.begin(), compress.end(), greater<pair<int, int>>());

        vector<int> rank_of(n);
        for (int r = 0; r < n; r++)
            rank_of[compress[r].second] = r;

        // Pre-size node storage for one full build plus one root-to-leaf
        // path per insertion (plus the placeholder). This avoids repeated
        // reallocation of the ~2M-node arrays.
        int depth = 0;
        while ((1LL << depth) < static_cast<long long>(MX) + 1)
            depth++;
        const size_t node_budget = 1 + 2 * (static_cast<size_t>(MX) + 1) +
                                   static_cast<size_t>(n) * (depth + 1);
        L.reserve(node_budget);
        R.reserve(node_budget);
        S.reserve(node_budget);
        KS.reserve(node_budget);

        // Placeholder node 0, so real node indices start at 1.
        L.push_back(-1);
        R.push_back(-1);
        S.push_back(0);
        KS.push_back(0);

        roots.assign(n + 1, 0);
        roots[0] = build(0, MX);
        for (int i = 0; i < n; i++)
            roots[i + 1] = update(roots[i], 0, MX, rank_of[i]);
    }

    // Appends a node with explicit contents and returns its index.
    // ks is long long so sums are never narrowed.
    int new_vertex(int s, long long ks, int l, int r)
    {
        S.push_back(s);
        KS.push_back(ks);
        L.push_back(l);
        R.push_back(r);
        return NEXT_FREE_INDEX++;
    }

    // Appends an internal node whose aggregates are the sums of children l
    // and r, and returns its index.
    int new_vertex(int l, int r)
    {
        return new_vertex(S[l] + S[r], KS[l] + KS[r], l, r);
    }

    // Builds an all-empty tree over leaf positions [tl, tr]; returns its root.
    int build(int tl, int tr)
    {
        if (tl == tr)
            return new_vertex(0, 0, -1, -1);
        int tm = (tl + tr) / 2;
        int left = build(tl, tm);
        int right = build(tm + 1, tr);
        return new_vertex(left, right);
    }

    // Returns the root of a new version: version `idx` plus one element at
    // leaf `pos`. Untouched subtrees are shared with `idx`. Child indices are
    // read into locals before new nodes are appended, so vector growth
    // cannot affect them.
    int update(int idx, int tl, int tr, int pos)
    {
        if (tl == tr)
            return new_vertex(S[idx] + 1, KS[idx] + compress[pos].first, -1, -1);
        int tm = (tl + tr) / 2;
        if (pos <= tm)
        {
            int left = update(L[idx], tl, tm, pos);
            return new_vertex(left, R[idx]);
        }
        int right = update(R[idx], tm + 1, tr, pos);
        return new_vertex(L[idx], right);
    }

    // Sum of the k largest elements present in version vr but not in vl,
    // restricted to leaves [tl, tr]. Requires k >= 1.
    long long find_kth(int vl, int vr, int tl, int tr, int k)
    {
        // A leaf differs between versions by at most one element, and k >= 1,
        // so its whole contribution is taken.
        if (tl == tr)
            return KS[vr] - KS[vl];
        int tm = (tl + tr) / 2;
        int left_count = S[L[vr]] - S[L[vl]];
        if (left_count >= k)
            return find_kth(L[vl], L[vr], tl, tm, k);
        return KS[L[vr]] - KS[L[vl]] + find_kth(R[vl], R[vr], tm + 1, tr, k - left_count);
    }

    // Sum of the k largest values among v[l], ..., v[r - 1]. If fewer than k
    // values exist, all of them are summed. Returns 0 for k <= 0 or an empty
    // range.
    long long find_kth(int l, int r, int k)
    {
        if (k <= 0 || l >= r)
            return 0;
        return find_kth(roots[l], roots[r], 0, MX, k);
    }
};

long long int findMaxAttraction(int n, int start, int d, int attraction[])
{
    long long int ans = 0;
    PersistentSegTree pst(attraction, n);
    function<long long(int, int)> C1 = [&](int k, int mid) -> long long
    {
        int days_rem = d - (mid - k + start - k);
        if (days_rem <= 0)
            return 0;
        return pst.find_kth(k, mid + 1, days_rem);
        // return mid - k;
    };
    function<long long(int, int)> C2 = [&](int mid, int k) -> long long
    {
        int days_rem = d - (k - mid + k - start);
        if (days_rem <= 0)
            return 0;
        return pst.find_kth(mid, k + 1, days_rem);
        // return k - mid;
    };

    // function<void(int, int, int, int)> dncg = [&](int l, int r, int optl, int optr)
    // {
    //     if (l > r)
    //         return;

    //     int mid = (l + r) >> 1;
    //     pair<long long, int> best = {-1, -1};

    //     for (int k = optl; k <= min(mid, optr); k++)
    //         best = max(best, {C1(k, mid), k});

    //     // debug(dp, best, ndp);
    //     // debug(ans, best, mid);
    //     ans = max(best.first, ans);
    //     int opt = best.second;

    //     dncg(l, mid - 1, optl, opt);
    //     dncg(mid + 1, r, opt, optr);
    // };
    // function<void(int, int, int, int)> dncs = [&](int l, int r, int optl, int optr)
    // {
    //     if (l > r)
    //         return;

    //     int mid = (l + r) >> 1;
    //     pair<long long, int> best = {-1, -1};

    //     for (int k = optr; k >= max(mid, optl); k--)
    //         best = max(best, {C2(mid, k), k});

    //     // debug(dp, best, ndp);
    //     // debug(ans, best, mid);
    //     ans = max(best.first, ans);
    //     int opt = best.second;

    //     dncs(l, mid - 1, optl, opt);
    //     dncs(mid + 1, r, opt, optr);
    // };
    if (d == 0)
        return 0;
    ans = attraction[start];
    vector<array<int, 4>> dncg, dncs;
    dncg.push_back({start + 1, n - 1, 0, start});
    dncs.push_back({0, start - 1, start, n - 1});
    while (!dncg.empty())
    {
        auto [l, r, optl, optr] = dncg.back();
        dncg.pop_back();
        if (l > r)
            continue;

        int mid = (l + r) >> 1;
        pair<long long, int> best = {-1, -1};

        for (int k = optl; k <= min(mid, optr); k++)
            best = max(best, {C1(k, mid), k});

        // debug(dp, best, ndp);
        // debug(ans, best, mid);
        ans = max(best.first, ans);
        int opt = best.second;

        dncg.push_back({mid + 1, r, opt, optr});
        dncg.push_back({l, mid - 1, optl, opt});
    }
    while (!dncs.empty())
    {
        auto [l, r, optl, optr] = dncs.back();
        dncs.pop_back();
        if (l > r)
            continue;

        int mid = (l + r) >> 1;
        pair<long long, int> best = {-1, -1};

        for (int k = optr; k >= max(mid, optl); k--)
            best = max(best, {C2(mid, k), k});

        // debug(dp, best, ndp);
        // debug(ans, best, mid);
        ans = max(best.first, ans);
        int opt = best.second;
        dncs.push_back({mid + 1, r, opt, optr});
        dncs.push_back({l, mid - 1, optl, opt});
    }
    // dncg(start + 1, n - 1, 0, start);
    // dncs(0, start - 1, start, n - 1);

    return ans;
}
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...