#include "holiday.h"
#include "bits/stdc++.h"
#define fi first
#define sc second
using namespace std;

// Integer aliases used throughout the solution.
using ll = long long;
using pll = pair<ll, ll>;

// Capacity of the per-position and per-day-budget arrays (1-based indexing).
constexpr ll maxn = 300005;

// Input, filled by findMaxAttraction().
ll n;        // number of positions stored in a[1..n]
ll a[maxn];  // a[i]: attraction value at position i (1-based)
ll d;        // day budget
ll st;       // starting position (1-based)

// Argmax scratch written by the divide-and-conquer passes.
ll opt[maxn];

// Segment tree over ranks: t[v] = {sum of stored values, number of stored values}.
pll t[2 * maxn];
ll ls[2 * maxn];  // ls[v]: left child of node v
ll rs[2 * maxn];  // rs[v]: right child of node v
ll tsz = 0;       // number of nodes handed out so far by init()
ll root = 0;      // root node index (0 until init() allocates it)

// Rank bookkeeping: b[rank] = {value, position} after sorting, id[position] = rank.
ll id[maxn];
pll b[maxn];
// Builds (on first use) and clears the segment-tree node covering [tl, tr]
// together with its whole subtree.
// `v` is a reference into the caller's slot (root, ls[...] or rs[...]), so a
// freshly allocated index from tsz is stored there directly. On later calls
// the existing indices are kept and only the {sum, count} pairs are reset.
void init(ll &v, ll tl, ll tr)
{
    if (v == 0)
    {
        ++tsz;
        v = tsz;
    }
    // Children never share v's index, so clearing before descending is safe.
    t[v] = {0, 0};
    if (tl != tr)
    {
        const ll mid = (tl + tr) / 2;
        init(ls[v], tl, mid);
        init(rs[v], mid + 1, tr);
    }
}
// Returns the sum of the k largest values stored in the subtree of node v
// (covering rank range [tl, tr]). If fewer than k values are stored, it
// returns the sum of all of them.
// Walks down one root-to-leaf path: whenever the right (larger-rank) child
// can be taken whole, its sum is banked and the search continues on the left.
ll get(ll v, ll tl, ll tr, ll k)
{
    ll acc = 0;
    for (;;)
    {
        if (k == 0)
            return acc;
        if (k >= t[v].second)
            return acc + t[v].first;
        const ll mid = (tl + tr) / 2;
        const pll &right = t[rs[v]];
        if (k >= right.second)
        {
            acc += right.first;
            k -= right.second;
            v = ls[v];
            tr = mid;
        }
        else
        {
            v = rs[v];
            tl = mid + 1;
        }
    }
}
// Inserts (e == true) or removes (e == false) one value in the rank-indexed
// segment tree and refreshes the {sum, count} pairs on the path to the root.
// On the root call, i is a position and is translated to its rank via id[].
// Recursive calls on child nodes receive the rank directly. The leaf stores
// b[rank].first, the value that owns that rank.
void upd(ll v, ll tl, ll tr, ll i, bool e)
{
    const ll pos = (v == root) ? id[i] : i;
    if (tl == tr)
    {
        t[v] = e ? pll{b[pos].first, 1} : pll{0, 0};
        return;
    }
    const ll mid = (tl + tr) / 2;
    if (pos > mid)
        upd(rs[v], mid + 1, tr, pos, e);
    else
        upd(ls[v], tl, mid, pos, e);
    const pll &lhs = t[ls[v]];
    const pll &rhs = t[rs[v]];
    t[v] = {lhs.first + rhs.first, lhs.second + rhs.second};
}
// State shared by the divide-and-conquer passes dp() and dp2().
ll cr;           // moving boundary of the window of positions currently in the tree
ll dpr[maxn];    // dpr[k]: best value computed by dp() for a budget of k
ll dpl[maxn];    // dpl[k]: best value computed by dp2() for a budget of k
bool bf{false};  // disabled debug flag; nothing in this file sets it to true
// Divide-and-conquer pass for the "go right of st, then come back" part.
//
// For every budget k in [l, r] it computes
//   dpr[k] = max over i in [tl, tr] of top(k - 2 * (i - st)),
// where top(c) is the sum of the c largest values among positions st..i.
// The round trip st -> i -> st uses 2 * (i - st) of the budget, and the rest
// is spent on the best positions of that stretch. opt[k] receives the first
// maximizing i, or -1 if no i in range fits the budget. The recursion narrows
// the search on the assumption that this optimum never moves left as k grows.
//
// Relies on the invariant, established by reshi(), that the tree holds
// exactly positions st..cr. cr is first slid to tl, then advanced with i.
void dp(ll l, ll r, ll tl, ll tr)
{
    if (l > r)
        return;
    ll mid = (l + r) / 2;
    // Slide the inserted window to exactly st..tl.
    while (cr > tl)
    {
        upd(root, 1, n, cr, 0);
        cr--;
    }
    while (cr < tl)
    {
        upd(root, 1, n, cr + 1, 1);
        cr++;
    }
    opt[mid] = -1;
    dpr[mid] = -1;
    for (ll i = tl; i <= tr; i++)
    {
        // Budget left for visiting after the round trip st -> i -> st.
        ll ck = mid - 2 * (i - st);
        if (ck < 0)
            break; // ck only shrinks as i grows
        ll cur = get(root, 1, n, ck);
        if (cur > dpr[mid])
        {
            dpr[mid] = cur;
            opt[mid] = i;
        }
        // Extend the window to st..i+1 for the next candidate.
        if (i < n)
        {
            upd(root, 1, n, i + 1, 1);
            cr = i + 1;
        }
    }
    if (opt[mid] != -1)
    {
        dp(l, mid - 1, tl, opt[mid]);
        dp(mid + 1, r, opt[mid], tr);
    }
    else
    {
        dp(l, mid - 1, tl, tr);
        dp(mid + 1, r, tl, tr);
    }
    // -1 marks "no feasible i"; report that as 0.
    dpr[mid] = max(dpr[mid], 0LL);
}
// Divide-and-conquer pass for the one-way "go left of st" part.
//
// For every budget k in [l, r] it computes
//   dpl[k] = max over i in [tl, tr] of top(k - (st - i)),
// where top(c) is the sum of the c largest values among positions i..st-1.
// Walking st -> i uses st - i of the budget. opt[k] receives the maximizing i
// found first while scanning i downward from tr, or -1 if none fits. The
// recursion assumes this optimum never moves right as k grows.
//
// Relies on the invariant, established by reshi(), that the tree holds
// exactly positions cr..st-1. cr is first slid to tr, then moved left with i.
void dp2(ll l, ll r, ll tl, ll tr)
{
    if (l > r)
        return;
    ll mid = (l + r) / 2;
    // Slide the inserted window to exactly tr..st-1.
    while (cr < tr)
    {
        upd(root, 1, n, cr, 0);
        cr++;
    }
    while (cr > tr)
    {
        upd(root, 1, n, cr - 1, 1);
        cr--;
    }
    opt[mid] = -1;
    dpl[mid] = -1;
    for (ll i = tr; i >= tl; i--)
    {
        // Budget left for visiting after walking st -> i.
        ll ck = mid - (st - i);
        // ck only shrinks as i decreases, so no remaining i can fit. Stopping
        // here leaves cr == i, the same state the old skip-to-end produced.
        if (ck < 0)
            break;
        ll cur = get(root, 1, n, ck);
        if (cur > dpl[mid])
        {
            dpl[mid] = cur;
            opt[mid] = i;
        }
        // Extend the window to i-1..st-1 for the next candidate.
        if (i > 1)
        {
            upd(root, 1, n, i - 1, 1);
            cr = i - 1;
        }
    }
    if (opt[mid] != -1)
    {
        dp2(mid + 1, r, tl, opt[mid]);
        dp2(l, mid - 1, opt[mid], tr);
    }
    else
    {
        dp2(mid + 1, r, tl, tr);
        dp2(l, mid - 1, tl, tr);
    }
    // -1 marks "no feasible i"; report that as 0.
    dpl[mid] = max(dpl[mid], 0LL);
}
// Solves one orientation of the problem on the current a[1..n], st and d.
// dpr comes from a right-and-return pass starting at st, and dpl from a
// one-way left pass over positions before st. The result is
//   max over x in [0, d] of dpl[x] + dpr[d - x].
// findMaxAttraction() calls this again on the reversed array for the
// mirrored orientation.
ll reshi()
{
    // Rank positions by value so the segment tree can be indexed by rank.
    for (ll pos = 1; pos <= n; ++pos)
        b[pos] = pll(a[pos], pos);
    sort(b + 1, b + n + 1);
    for (ll rank = 1; rank <= n; ++rank)
        id[b[rank].sc] = rank;

    // Right pass: the window st..cr starts as {st}.
    init(root, 1, n);
    cr = st;
    upd(root, 1, n, st, 1);
    for (ll k = 0; k <= d; ++k)
    {
        dpr[k] = 0;
        dpl[k] = 0;
        opt[k] = 0;
    }
    dp(1, d, st, n);

    // Left pass: the window cr..st-1 starts as {st-1}, if that position exists.
    for (ll k = 0; k <= d; ++k)
    {
        dpl[k] = 0;
        opt[k] = 0;
    }
    init(root, 1, n);
    cr = st - 1;
    if (cr != 0)
        upd(root, 1, n, cr, 1);
    dp2(2, d, 1, st - 1);

    // Combine: x days on the left leg, the rest on the right round trip.
    dpl[0] = 0;
    dpl[1] = 0;
    dpr[0] = 0;
    ll best = 0;
    for (ll x = 0; x <= d; ++x)
        best = max(best, dpl[x] + dpr[d - x]);
    return best;
}
// Grader entry point. attraction[0..N-1] holds the values, str is the
// 0-based start, and D is the day budget. Returns the best total.
//
// reshi() evaluates routes that go right of the start and return before
// heading left. Reversing the array and mirroring the start covers the
// opposite orientation.
ll findMaxAttraction(int N, int str, int D, int attraction[])
{
    n = N;
    st = str + 1; // switch to 1-based indexing
    // dpr/dpl/opt are indexed by budget up to d but only hold maxn entries,
    // so an oversized D would write out of bounds. Every total reshi() can
    // form is reachable within 3 * n days:
    //   right round trip 2*(n-st) + (n-st+1), plus left leg 2*(st-1).
    // With non-negative values (assumed from the task constraints), extra
    // days never increase the answer, so the cap leaves results unchanged.
    d = min<ll>(D, 3 * n);
    for (ll i = 1; i <= n; i++)
        a[i] = attraction[i - 1];
    ll ans = reshi();
    reverse(a + 1, a + 1 + n);
    st = n - st + 1;
    ans = max(ans, reshi());
    return ans;
}
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |