Submission #1158863

# | Time | Username | Problem | Language | Result | Execution time | Memory
1158863 |  | hungnt | Holiday (IOI14_holiday) | C++20 | 100 / 100 | 930 ms | 16116 KiB
#include "holiday.h"
#include "bits/stdc++.h"
// Short-hands for std::pair members.
#define fi first
#define sc second
using namespace std;

using ll = long long;
using pll = pair<ll, ll>;
// Capacity of the fixed-size arrays, which are indexed either by city or by
// day budget (0..d).
const ll maxn = 300005;
// Number of cities; cities are 1-indexed (see findMaxAttraction).
ll n;
// a[i]: attraction of city i (1-indexed).
ll a[maxn];
// Number of days available.
ll d;
// Starting city, 1-indexed.
ll st;
// opt[k]: endpoint chosen for day budget k by dp()/dp2(); -1 while none is affordable.
ll opt[maxn];
// Segment tree over attraction ranks: t[v] = {sum of active values, number of active leaves}.
pll t[2 * maxn];
// Child links, allocated-node counter and root index (0 means "not allocated yet").
ll ls[2 * maxn], rs[2 * maxn], tsz = 0, root = 0;
// id[city]: rank of the city in b[], i.e. the leaf position used for it in the tree.
ll id[maxn];
// b[rank] = {attraction, city}, filled and sorted ascending in reshi().
pll b[maxn];

// Prepares the subtree rooted at v that covers ranks [tl, tr]:
// a node index is allocated the first time a slot is seen (v == 0), and every
// node in the subtree is reset to {sum, count} = {0, 0}. Already-allocated
// indices are reused, so a second call simply clears the tree.
void init(ll &v, ll tl, ll tr)
{
    v = v ? v : ++tsz;
    if (tl != tr)
    {
        const ll half = (tl + tr) / 2;
        init(ls[v], tl, half);
        init(rs[v], half + 1, tr);
    }
    t[v] = {0, 0};
}

// Returns the sum of the k largest active values stored under node v, which
// covers ranks [tl, tr]. If fewer than k leaves are active, returns the sum of
// all of them. Ranks follow b[] (sorted ascending), so the right child always
// holds the larger values and is consumed first.
ll get(ll v, ll tl, ll tr, ll k)
{
    ll acc = 0;
    for (;;)
    {
        if (k == 0)
            return acc;
        if (t[v].sc <= k)
            return acc + t[v].fi;
        const ll half = (tl + tr) / 2;
        const pll &right = t[rs[v]];
        if (right.sc <= k)
        {
            // Whole right child fits: take it and continue in the left child.
            acc += right.fi;
            k -= right.sc;
            v = ls[v];
            tr = half;
        }
        else
        {
            // All k values come from the right child.
            v = rs[v];
            tl = half + 1;
        }
    }
}

// Activates (e == true) or deactivates (e == false) one leaf and refreshes the
// {sum, count} aggregates on the path back up. When called on the root, `i`
// is a city index and is translated to its rank via id[]; recursive calls on
// child nodes already receive the rank.
void upd(ll v, ll tl, ll tr, ll i, bool e)
{
    const ll pos = (v == root) ? id[i] : i;
    if (tl == tr)
    {
        t[v] = e ? pll{b[pos].fi, 1} : pll{0, 0};
        return;
    }
    const ll mid = (tl + tr) / 2;
    if (pos > mid)
        upd(rs[v], mid + 1, tr, pos, e);
    else
        upd(ls[v], tl, mid, pos, e);
    const pll &lo = t[ls[v]];
    const pll &hi = t[rs[v]];
    t[v] = {lo.fi + hi.fi, lo.sc + hi.sc};
}

// Moving boundary of the city window currently held in the segment tree
// (right end of [st, cr] in dp(), left end of [cr, st - 1] in dp2()).
ll cr;
// dpr[k]: best value for budget k on the part of the trip right of the start
// (including st); dp() charges 2*(i - st) travel days, i.e. out and back.
ll dpr[maxn];
// dpl[k]: best value for budget k on the part left of the start (excluding st);
// dp2() charges st - i travel days, i.e. one way.
ll dpl[maxn];
// Debug switch: when true, dp()/dp2() ignore their search bounds and scan the
// full range. It is never set to true in this file.
bool bf = 0;

// Divide-and-conquer fill of dpr[] for every day budget k in [l, r].
//
// dpr[k] = max over right endpoints i in [tl, tr] of the sum of the
// (k - 2*(i - st)) largest attractions among cities [st, i]; the factor 2
// presumably pays for walking st -> i and back to st. opt[k] receives the
// first i reaching that maximum, or -1 if no i in [tl, tr] is affordable.
//
// The recursion assumes opt[] is non-decreasing in k: budgets below mid only
// search [tl, opt[mid]] and budgets above mid only search [opt[mid], tr].
//
// Segment-tree state: the tree is expected to hold exactly cities [st, cr] on
// entry (reshi() sets cr = st and inserts st). cr is moved lazily, so each
// call only pays for the distance between consecutive windows.
void dp(ll l, ll r, ll tl, ll tr)
{
    if (l > r)
        return;
    ll mid = (l + r) / 2;
    // Debug-only full scan; bf is never set in this file.
    if (bf)
        tl = st, tr = n;
    // Shrink, then grow, the window until the tree holds exactly [st, tl].
    while (cr > tl)
    {
        upd(root, 1, n, cr, 0);
        cr--;
    }
    while (cr < tl)
    {
        upd(root, 1, n, cr + 1, 1);
        cr++;
    }
    // -1 sentinels mean "no affordable endpoint found yet".
    opt[mid] = -1;
    dpr[mid] = -1;
    for (ll i = tl; i <= tr; i++)
    {
        // Days left for visiting after the travel cost to endpoint i.
        ll ck = mid - 2 * (i - st);
        // ck only decreases as i grows, so no later endpoint is affordable.
        if (ck < 0)
            break;
        ll cur = get(root, 1, n, ck);
        if (cur > dpr[mid])
        {
            dpr[mid] = cur;
            opt[mid] = i;
        }
        // Extend the window to [st, i + 1] for the next candidate endpoint.
        if (i < n)
        {
            upd(root, 1, n, i + 1, 1);
            cr = i + 1;
        }
    }
    if (opt[mid] != -1)
    {
        dp(l, mid - 1, tl, opt[mid]);
        dp(mid + 1, r, opt[mid], tr);
    }
    else
    {
        // Nothing affordable for mid: both halves keep the full range.
        dp(l, mid - 1, tl, tr);
        dp(mid + 1, r, tl, tr);
    }
    // Replace the -1 sentinel; done after recursion, which never reads dpr[mid].
    dpr[mid] = max(dpr[mid], 0LL);
}

// Divide-and-conquer fill of dpl[] for every day budget k in [l, r]; the mirror
// image of dp() for the part of the trip left of the start city.
//
// dpl[k] = max over left endpoints i in [tl, tr] of the sum of the
// (k - (st - i)) largest attractions among cities [i, st - 1]; st - i is the
// one-way walk from st to i, and st itself is not in the window. Endpoints are
// scanned from tr down to tl; opt[k] receives the first one reaching the
// maximum, or -1 if none is affordable.
//
// Larger budgets reach further left: budgets above mid search [tl, opt[mid]],
// budgets below mid search [opt[mid], tr].
//
// Segment-tree state: the tree is expected to hold exactly cities [cr, st - 1]
// on entry (reshi() sets cr = st - 1 and, when st > 1, inserts that city).
void dp2(ll l, ll r, ll tl, ll tr)
{
    if (l > r)
        return;
    ll mid = (l + r) / 2;
    // Debug-only full scan; bf is never set in this file.
    if (bf)
        tl = 1, tr = st - 1;
    // Shrink, then grow, the window until the tree holds exactly [tr, st - 1].
    while (cr < tr)
    {
        upd(root, 1, n, cr, 0);
        cr++;
    }
    while (cr > tr)
    {
        upd(root, 1, n, cr - 1, 1);
        cr--;
    }
    // -1 sentinels mean "no affordable endpoint found yet".
    opt[mid] = -1;
    dpl[mid] = -1;
    for (ll i = tr; i >= tl; i--)
    {
        // Days left for visiting after walking from st to i.
        ll ck = mid - (st - i);
        // Unaffordable (and so is every smaller i); the loop still continues,
        // which keeps growing the window and keeps cr consistent.
        if (ck < 0)
            continue;
        ll cur = get(root, 1, n, ck);
        if (cur > dpl[mid])
        {
            dpl[mid] = cur;
            opt[mid] = i;
        }
        // Extend the window to [i - 1, st - 1] for the next candidate endpoint.
        if (i > 1)
        {
            upd(root, 1, n, i - 1, 1);
            cr = i - 1;
        }
    }
    if (opt[mid] != -1)
    {
        dp2(mid + 1, r, tl, opt[mid]);
        dp2(l, mid - 1, opt[mid], tr);
    }
    else
    {
        // Nothing affordable for mid: both halves keep the full range.
        dp2(mid + 1, r, tl, tr);
        dp2(l, mid - 1, tl, tr);
    }
    // Replace the -1 sentinel; done after recursion, which never reads dpl[mid].
    dpl[mid] = max(dpl[mid], 0LL);
}

// Solves one orientation for the current a[], st and d: the right part
// (cities >= st) is priced as a round trip via dp(), the left part
// (cities < st) as a one-way walk via dp2(), and the best split of the d days
// between them is returned.
ll reshi()
{
    // Rank cities by (attraction, index); id[city] is the leaf used for it.
    for (ll city = 1; city <= n; city++)
        b[city] = pll(a[city], city);
    sort(b + 1, b + n + 1);
    for (ll rank = 1; rank <= n; rank++)
        id[b[rank].sc] = rank;

    // Right part: window starts as the single city [st, st].
    init(root, 1, n);
    cr = st;
    upd(root, 1, n, st, 1);
    for (ll k = 0; k <= d; k++)
    {
        dpr[k] = 0;
        dpl[k] = 0;
        opt[k] = 0;
    }
    dp(1, d, st, n);

    // Left part: clear the tree and start from [st - 1, st - 1] (empty if st == 1).
    for (ll k = 0; k <= d; k++)
    {
        dpl[k] = 0;
        opt[k] = 0;
    }
    init(root, 1, n);
    cr = st - 1;
    if (cr != 0)
        upd(root, 1, n, cr, 1);
    dp2(2, d, 1, st - 1);

    // dp2 only fills budgets >= 2; these base entries are pinned to 0.
    dpl[0] = 0;
    dpl[1] = 0;
    dpr[0] = 0;
    ll best = 0;
    for (ll leftDays = 0; leftDays <= d; leftDays++)
        best = max(best, dpl[leftDays] + dpr[d - leftDays]);
    return best;
}

// Entry point. N cities with attractions attraction[0..N-1], 0-indexed start
// city `str`, and D days. reshi() evaluates plans that walk right and return to
// the start before walking left; mirroring the city array lets the same code
// evaluate the symmetric plans. Returns the better of the two.
ll findMaxAttraction(int N, int str, int D, int attraction[])
{
    n = N, st = str + 1;
    // In reshi()'s cost model a plan costs 2*(r - st) + (st - l) travel days
    // plus one day per visited city, which is always < 3n. Days beyond 3n can
    // therefore never raise the answer. Clamping keeps the result unchanged,
    // bounds the work, and stops a large D from indexing past the fixed-size
    // dpr/dpl/opt arrays (valid while 3n < maxn).
    d = min<ll>(D, 3 * n);
    for (ll i = 1; i <= n; i++)
        a[i] = attraction[i - 1];
    ll ans = reshi();
    // Mirror the cities so reshi() now covers the left-first orientation.
    reverse(a + 1, a + 1 + n);
    st = n - st + 1;
    ans = max(ans, reshi());
    return ans;
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...