#include <bits/stdc++.h>
#include "holiday.h"
using namespace std;
// Make `endl` expand to a plain newline character.
#define endl '\n'
// Shorthand for `long long`, used for counts and sums throughout the file.
#define ll long long
// Expands to the (begin, end) iterator pair of container x.
#define all(x) (x).begin(), (x).end()
// Capacity passed to the global segment tree's constructor.
const int mxn = 1e5 + 100;
struct Node {
ll act = 0, val = 0;
Node operator + (Node a) {
return {act + a.act, val * (act > 0) + a.val * (a.act > 0)};
}
};
// a   - attraction values; sorted ascending inside findMaxAttraction.
// ord - original indices, sorted by their attraction value.
// ind - ind[j] is the position of original index j within the sorted order.
vector<ll> a, ord, ind;
struct SGT {
vector<Node> sgt;
SGT(int sz) {
sgt.resize(4 * sz);
}
void build(int k, int l, int r) {
if (l == r) {
sgt[k].val = a[l];
return;
}
int mid = (l + r) / 2;
build(k * 2, l, mid);
build(k * 2 + 1, mid + 1, r);
sgt[k] = sgt[k * 2] + sgt[k * 2 + 1];
}
void update(int k, int l, int r, int ind, int val) {
if (l > ind || r < ind) return;
if (l == r) {
sgt[k].act += val;
return;
}
int mid = (l + r) / 2;
update(k * 2, l, mid, ind, val);
update(k * 2 + 1, mid + 1, r, ind, val);
sgt[k] = sgt[k * 2] + sgt[k * 2 + 1];
}
ll get(int k, int l, int r, int i) {
if (l > r || !i) return 0;
if (i >= sgt[k].act) return sgt[k].val * (sgt[k].act > 0);
int mid = (l + r) / 2;
if (sgt[k * 2 + 1].act > i) return get(k * 2 + 1, mid + 1, r, i);
return get(k * 2, l, mid, i - sgt[k * 2 + 1].act) + sgt[k * 2 + 1].val * (sgt[k * 2 + 1].act > 0);
}
} sgt(mxn);
// n - number of entries (set from _n); s - start index (set from start).
int n, s;
// Inclusive bounds [lp, rp] of the window currently activated in the segment
// tree by get(); the initial value [0, -1] is the empty window.
int lp = 0, rp = -1;
// Moves the active window [lp, rp] to [L, R] one position at a time, then
// returns the segment tree's answer for a budget of x picks.
ll get(int L, int R, int x) {
    // Changes the activation count of original index `pos` by `delta`.
    const auto toggle = [](int pos, int delta) {
        sgt.update(1, 0, n - 1, ind[pos], delta);
    };

    for (; lp < L; ++lp) toggle(lp, -1);   // shrink from the left
    while (lp > L) toggle(--lp, 1);        // grow to the left
    for (; rp > R; --rp) toggle(rp, -1);   // shrink from the right
    while (rp < R) toggle(++rp, 1);        // grow to the right

    return sgt.get(1, 0, n - 1, x);
}
void solveL(int l, int r, int ml, int mr, int d, ll* arr) {
if (l > r) return;
int mid = (l + r) / 2;
ll mx = -1, opt = ml;
for (int i = mr; i >= ml; i--) {
if ((s - i) * d > mid) break;
ll sum = get(min(i, s - (d == 1)), s - (d == 1), mid - (s - i) * d);
if (sum > mx) {
mx = sum;
opt = i;
}
}
arr[mid] = mx;
solveL(l, mid - 1, opt, mr, d, arr);
solveL(mid + 1, r, ml, opt, d, arr);
}
void solveR(int l, int r, int ml, int mr, int d, ll *arr) {
if (l > r) return;
int mid = (l + r) / 2;
ll mx = -1, opt = ml;
for (int i = ml; i <= mr; i++) {
if ((i - s) * d > mid) break;
ll sum = get(s + (d == 1), max(s + (d == 1), i), mid - (i - s) * d);
if (sum > mx) {
mx = sum;
opt = i;
}
}
arr[mid] = mx;
solveR(l, mid - 1, ml, opt, d, arr);
solveR(mid + 1, r, opt, mr, d, arr);
}
// Entry point: returns the best total attraction value reachable in d days
// when starting at index `start` among the _n values in attraction[].
//
// Four per-day tables are built (one-way / round-trip, left / right) and the
// answer combines a round trip on one side with a one-way trip on the other.
ll findMaxAttraction(int _n, int start, int d, int attraction[]) {
    n = _n;
    s = start;

    // Reset all file-level state so that a repeated call starts clean
    // (previously `ord` kept growing and the tree kept stale counters).
    sgt = SGT(max(n, mxn));
    lp = 0;
    rp = -1;

    a.resize(n);
    ord.resize(n);
    for (int i = 0; i < n; i++) {
        a[i] = attraction[i];
        ord[i] = i;
    }
    // ord: original indices ordered by value; a: the values themselves sorted.
    sort(all(ord), [&] (ll a1, ll a2) {
        return a[a1] < a[a2];
    });
    sort(all(a));
    sgt.build(1, 0, n - 1);

    // ind[j] = position of original index j in the sorted order.
    ind.resize(n);
    for (int i = 0; i < n; i++) {
        ind[ord[i]] = i;
    }

    // Heap-allocated tables: d can be large, and variable-length arrays are
    // both non-standard C++ and a stack-overflow risk.
    vector<ll> ansl(d + 2), ansL(d + 2), ansr(d + 2), ansR(d + 2);
    solveL(0, d, 0, s, 1, ansl.data());
    solveL(0, d, 0, s, 2, ansL.data());
    solveR(0, d, s, n - 1, 1, ansr.data());
    solveR(0, d, s, n - 1, 2, ansR.data());

    ll ans = max(ansl[d], ansr[d]);
    for (int i = 0; i <= d; i++) {
        // i days on one side, the remaining d - i days on the other.
        ans = max(ans, ansl[i] + ansR[d - i]);
        ans = max(ans, ansL[i] + ansr[d - i]);
    }
    return ans;
}
// # | Verdict | Execution time | Memory | Grader output |
// ---|
// Fetching results... |
// # | Verdict | Execution time | Memory | Grader output |
// ---|
// Fetching results... |
// # | Verdict | Execution time | Memory | Grader output |
// ---|
// Fetching results... |
// # | Verdict | Execution time | Memory | Grader output |
// ---|
// Fetching results... |