Submission #1132287

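| # | Time | Username | Problem | Language | Result | Execution time | Memory |
|---|------|----------|---------|----------|--------|----------------|--------|
| 1132287 | | alterio | Holiday (IOI14_holiday) | C++20 | 100 / 100 | 806 ms | 17340 KiB |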
#include <bits/stdc++.h>
#include "holiday.h"

using namespace std;

// 64-bit integer shorthand, used for sums of attraction values.
using ll = long long;
// Whole-container iterator range, e.g. sort(all(v)).
#define all(x) (x).begin(), (x).end()

// Segment-tree capacity (number of leaf positions). Presumably sized to cover the
// problem's limit on the number of cities with some slack -- confirm against the statement.
constexpr int mxn = 100000 + 100;

struct Node {
    ll act = 0, val = 0;
    
    Node operator + (Node a) {
        return {act + a.act, val * (act > 0) + a.val * (a.act > 0)};
    }
};

// Attraction values; findMaxAttraction sorts them ascending, so a[k] is the k-th smallest.
vector<ll> a;
// ord[k]: original index of the city holding the k-th smallest attraction.
vector<ll> ord;
// ind[city]: rank of that city's attraction, i.e. its leaf position in the segment tree.
vector<ll> ind;

// Segment tree over attraction ranks. Leaf i stores a[i] (the i-th smallest attraction once
// `a` has been sorted) together with an activity count. The tree answers
// "sum of the i largest active values".
struct SGT {
    vector<Node> sgt;

    // Reserves 4 * sz nodes, enough for a recursive tree over `sz` leaves.
    explicit SGT(int sz) {
        sgt.resize(4 * sz);
    }

    // Loads leaf values from the global array `a` for positions [l, r] under node k, then
    // recomputes the internal nodes. Leaf activity counts are left unchanged.
    void build(int k, int l, int r) {
        if (l == r) {
            sgt[k].val = a[l];
            return;
        }
        int mid = (l + r) / 2;
        build(k * 2, l, mid);
        build(k * 2 + 1, mid + 1, r);
        sgt[k] = sgt[k * 2] + sgt[k * 2 + 1];
    }

    // Adds `val` to the activity count of leaf `pos`. Does nothing if pos is outside [l, r].
    void update(int k, int l, int r, int pos, int val) {
        if (pos < l || pos > r) return;
        if (l == r) {
            sgt[k].act += val;
            return;
        }
        int mid = (l + r) / 2;
        // Only the child that contains `pos` can change.
        if (pos <= mid) update(k * 2, l, mid, pos, val);
        else update(k * 2 + 1, mid + 1, r, pos, val);
        sgt[k] = sgt[k * 2] + sgt[k * 2 + 1];
    }

    // Returns the sum of the `i` largest active values in [l, r]. If fewer than `i` values
    // are active, returns the sum of all of them. Non-positive `i` yields 0; without this
    // guard a negative `i` would recurse below the leaves.
    ll get(int k, int l, int r, int i) {
        if (l > r || i <= 0) return 0;
        if (i >= sgt[k].act) return sgt[k].val * (sgt[k].act > 0);
        int mid = (l + r) / 2;
        const Node& rc = sgt[k * 2 + 1];
        // Higher positions hold larger values (a is sorted ascending), so take as many
        // values from the right child as possible.
        if (rc.act > i) return get(k * 2 + 1, mid + 1, r, i);
        return get(k * 2, l, mid, i - rc.act) + rc.val * (rc.act > 0);
    }
} sgt(mxn);

int n;       // number of cities
int s;       // index of the start city
int lp = 0;  // left end of the window of cities currently active in the segment tree
int rp = -1; // right end of that window; rp < lp means the window is empty

ll get(int L, int R, int x) {
    while (lp < L) {
        sgt.update(1, 0, n - 1, ind[lp], -1);
        lp++;
    }
    while (lp > L) {
        lp--;
        sgt.update(1, 0, n - 1, ind[lp], 1);
    }
    while (rp > R) {
        sgt.update(1, 0, n - 1, ind[rp], -1);
        rp--;
    }
    while (rp < R) {
        rp++;
        sgt.update(1, 0, n - 1, ind[rp], 1);
    }
    return sgt.get(1, 0, n - 1, x);
}

void solveL(int l, int r, int ml, int mr, int d, ll* arr) {
    if (l > r) return;
    int mid = (l + r) / 2;
    ll mx = -1, opt = ml;
    for (int i = mr; i >= ml; i--) {
        if ((s - i) * d > mid) break;
        ll sum = get(min(i, s - (d == 1)), s - (d == 1), mid - (s - i) * d);
        if (sum > mx) {
            mx = sum;
            opt = i;
        }
    }
    arr[mid] = mx;
    solveL(l, mid - 1, opt, mr, d, arr);
    solveL(mid + 1, r, ml, opt, d, arr);
}

void solveR(int l, int r, int ml, int mr, int d, ll *arr) {
    if (l > r) return;
    int mid = (l + r) / 2;
    ll mx = -1, opt = ml;
    for (int i = ml; i <= mr; i++) {
        if ((i - s) * d > mid) break;
        ll sum = get(s + (d == 1), max(s + (d == 1), i), mid - (i - s) * d);
        if (sum > mx) {
            mx = sum;
            opt = i;
        }
    }
    arr[mid] = mx;
    solveR(l, mid - 1, ml, opt, d, arr);
    solveR(mid + 1, r, opt, mr, d, arr);
}

// Returns the maximum total attraction collectable in `d` days when starting at city
// `start`, given the `_n` attraction values in `attraction`. Moving to an adjacent city
// costs one day, and so does visiting a city (collecting its attraction).
// Resets all global state first, so the function may safely be called more than once.
ll findMaxAttraction(int _n, int start, int d, int attraction[]) {
    n = _n;
    s = start;
    if (n <= 0 || d <= 0) return 0;

    // Start from an empty window and a segment tree with no active positions.
    lp = 0;
    rp = -1;
    fill(all(sgt.sgt), Node{});

    // Rank the cities by attraction. ord[k] is the city with the k-th smallest value and
    // ind[city] is that city's rank, i.e. its leaf in the segment tree.
    a.assign(attraction, attraction + n);
    ord.resize(n);
    iota(all(ord), 0);
    sort(all(ord), [&](ll x, ll y) { return a[x] < a[y]; });
    sort(all(a));
    sgt.build(1, 0, n - 1);
    ind.assign(n, 0);
    for (int i = 0; i < n; i++) ind[ord[i]] = i;

    // ansl/ansr: one-way walks (start city excluded).
    // ansL/ansR: out-and-back walks (start city included).
    // These live on the heap: d can be large, and variable-length arrays are not standard C++.
    vector<ll> ansl(d + 1), ansL(d + 1), ansr(d + 1), ansR(d + 1);
    solveL(0, d, 0, s, 1, ansl.data());
    solveL(0, d, 0, s, 2, ansL.data());
    solveR(0, d, s, n - 1, 1, ansr.data());
    solveR(0, d, s, n - 1, 2, ansR.data());

    ll ans = max(ansl[d], ansr[d]);
    // Split the days: go out and back on one side, then walk one way on the other side.
    for (int i = 0; i <= d; i++) {
        ans = max(ans, ansl[i] + ansR[d - i]);
        ans = max(ans, ansL[i] + ansr[d - i]);
    }
    return ans;
}
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...