#include <bits/stdc++.h>
#include "holiday.h"
// Shorthand macros. F / S expand to the member names `first` / `second`
// (intended for use after a `.` on a pair); MP expands to `make_pair`.
#define F first
#define S second
#define MP make_pair
// Container size as a signed int, which avoids signed/unsigned comparisons.
#define SZ(x) ((int)(x).size())
// Expands to the iterator pair `x.begin(), x.end()`. The argument is not
// parenthesised, so pass a plain container name rather than an expression.
#define ALL(x) x.begin(), x.end()
using namespace std;
// Short aliases for the integer, vector and pair types used in this file.
using ll = long long;
using vi = vector<int>;
using vl = vector<ll>;
using vvi = vector<vi>;
using vvl = vector<vl>;
using pii = pair<int, int>;
using pll = pair<ll, ll>;
using vii = vector<pii>;
using vll = vector<pll>;
// Quick sketch: the optimal path either goes right first then left, or left first then right.
// The second case reduces to the first by reversing the array, so we only need to handle
// the "right first" case.
//
// We enumerate over the rightmost city visited (determining how far right we go) and
// the leftmost city visited (determining how far left we go). For each choice of
// rightmost city, there exists an optimal leftmost city. After this, subtract the total distance
// traveled from the number of days, call it k, and then choose the top k cities with most attractions -> Persistent Segment Tree
//
// Notice that as the rightmost city increases,
// the corresponding optimal leftmost city also increases -> handled with d&c DP -> O(nlog²n)
// Capacity limits for the statically sized storage in this file.
// maxn bounds the per-city arrays (roots, dp); SIZE is the number of slots in
// the persistent segment tree's node pool.
static const int maxn = 1e5 + 5;
static const int SIZE = 3e6 + 10;

// Persistent segment tree over value ids in [l, r].
//
// Every leaf represents one value id and stores how many entries with that id
// were inserted (cnt) and their total (sum). upd() never modifies existing
// nodes: it copies the nodes on the root-to-leaf path, so every earlier root
// remains a valid, queryable version. Node id 0 is the shared empty tree.
//
// Assumption relied on by qq(): all entries inserted at the same id carry the
// same value (true when ids come from sort + unique coordinate compression).
struct SegTree {
    struct Node {
        int lc, rc;      // child node ids; 0 means "empty subtree"
        long long sum;   // total of all values stored in this subtree
        int cnt;         // number of values stored in this subtree
        Node(): lc(0), rc(0), sum(0), cnt(0) {}
    };

    int node_cnt = 0;              // id of the most recently allocated node
    std::array<Node, SIZE> node;   // node pool; slot 0 is never written

    // Allocates the next slot and returns its id (contents are left as-is).
    int create_node() {
        return ++node_cnt;
    }

    // Allocates the next slot, initialises it with a copy of nd, returns its id.
    int create_node(Node nd) {
        node[++node_cnt] = nd;
        return node_cnt;
    }

    // Combines the aggregates of two subtrees; child links of the result are 0.
    Node merge(const Node& a, const Node& b) {
        Node res;
        res.sum = a.sum + b.sum;
        res.cnt = a.cnt + b.cnt;
        return res;
    }

    // Recomputes sum/cnt of node id from its two children.
    void pull(int id) {
        const Node res = merge(node[node[id].lc], node[node[id].rc]);
        node[id].sum = res.sum;
        node[id].cnt = res.cnt;
    }

    // Returns the root of a new version equal to version u plus one entry of
    // value val at id pos. [l, r] is the id range covered by u.
    int upd(int u, int l, int r, int pos, int val) {
        const int nu = create_node(node[u]);
        if (l == r) {
            node[nu].cnt++;
            node[nu].sum += val;
            return nu;
        }
        const int mid = (l + r) >> 1;
        if (pos <= mid) node[nu].lc = upd(node[u].lc, l, mid, pos, val);
        else node[nu].rc = upd(node[u].rc, mid + 1, r, pos, val);
        pull(nu);
        return nu;
    }

    // Returns cur plus the sum of the k entries with the largest ids in the
    // combined multiset of versions u and v (all entries if fewer than k).
    // [l, r] is the id range covered by both u and v.
    long long qq(int u, int v, int l, int r, int k, long long cur) {
        if (k <= 0) return cur;  // nothing left to pick
        if (l == r) {
            const long long leaf_sum = node[u].sum + node[v].sum;
            const int leaf_cnt = node[u].cnt + node[v].cnt;
            if (leaf_cnt <= k) return cur + leaf_sum;
            // The leaf holds more equal-valued entries than we may still
            // take: count only k of them instead of the whole leaf.
            return cur + leaf_sum / leaf_cnt * k;
        }
        const int mid = (l + r) >> 1;
        const int cnt = node[node[u].rc].cnt + node[node[v].rc].cnt;
        const long long right_sum = node[node[u].rc].sum + node[node[v].rc].sum;
        if (cnt >= k) {
            // The k largest all live in the right (larger-id) half.
            return qq(node[u].rc, node[v].rc, mid + 1, r, k, cur);
        }
        // Take the whole right half and look for the rest on the left.
        return qq(node[u].lc, node[v].lc, l, mid, k - cnt, cur + right_sum);
    }
} seg;
// State shared by calc(), dc() and findMaxAttraction().
// roots[i]: root id (an index into seg.node) of the tree version stored for
// city i; 0 denotes the empty tree.
static array<int, maxn> roots;
// dp[r]: best value dc() found with r as the rightmost visited city.
static array<ll, maxn> dp;
// n: number of cities; m: number of distinct attraction values (the segment
// tree covers ids 0..m-1); start: index of the starting city in the current
// orientation of the array; d: total number of days available.
static int n, m, start, d;
// Value of the route that walks from start right to city r and then left to
// city l (l <= start <= r), spending every remaining day on a visit.
// Returns 0 when walking alone uses up all d days.
static ll calc(int l, int r) {
    const int walked = (r - start) + (r - l);  // start -> r, then r -> l
    const int visits = d - walked;             // days left for sightseeing
    if (visits <= 0) return 0;
    // When l == start the left-hand tree is replaced by root 0 (the empty
    // tree); roots[r] alone supplies the candidates in that case.
    const int left_root = (l == start) ? 0 : roots[l];
    return seg.qq(left_root, roots[r], 0, m - 1, visits, 0);
}
static void dc(int l, int r, int optl, int optr) {
if (l > r) return;
int mid = l + r >> 1;
ll best = 0; int best_id = -1;
for (int i = optl; i <= min(optr, mid); i++) {
ll cost = calc(i, mid);
if (cost >= best) {
best = cost;
best_id = i;
}
}
// cerr << mid << " transitions from " << best_id << '\n';
dp[mid] = best;
dc(l, mid - 1, optl, best_id);
dc(mid + 1, r, best_id, optr);
}
// Returns the maximum total number of attractions that can be visited.
//
// _n:         number of cities (attraction[0 .. _n-1])
// _start:     index of the city the trip starts in
// _d:         number of days available; moving to an adjacent city and
//             visiting a city each cost one day
// attraction: attractions per city
//
// Two phases are run: phase 0 handles "go right first, then left" on the
// array as given, phase 1 handles the mirrored case by reversing the array
// and mirroring the start index. Returns 0 for empty or out-of-range input.
ll findMaxAttraction(int _n, int _start, int _d, int attraction[]) {
    n = _n;
    start = _start;
    d = _d;
    if (n <= 0 || start < 0 || start >= n) return 0;

    // Coordinate-compress the attraction counts to ids 0..m-1 (ascending).
    vi attraction_cnt(n), attraction_cnt_id(n);
    for (int i = 0; i < n; i++) attraction_cnt[i] = attraction[i];
    vi sorted_attraction_cnt(attraction_cnt);
    sort(ALL(sorted_attraction_cnt));
    sorted_attraction_cnt.erase(unique(ALL(sorted_attraction_cnt)), sorted_attraction_cnt.end());
    for (int i = 0; i < n; i++) {
        attraction_cnt_id[i] = lower_bound(ALL(sorted_attraction_cnt), attraction_cnt[i]) - sorted_attraction_cnt.begin();
    }
    m = SZ(sorted_attraction_cnt);

    ll ans = 0;
    for (int tt = 0; tt < 2; tt++) {
        // Every root is rebuilt below, so the node pool can be recycled.
        // Without this, two phases of n insertions can overrun seg.node.
        seg.node_cnt = 0;
        for (int i = 0; i < n; i++) roots[i] = 0;

        // roots[i] for i < start holds cities i .. start-1 (roots[start] is
        // still 0 here, so the chain starts from the empty tree).
        for (int i = start - 1; i >= 0; i--) {
            roots[i] = seg.upd(roots[i + 1], 0, m - 1, attraction_cnt_id[i], attraction_cnt[i]);
        }
        // roots[i] for i >= start holds cities start .. i.
        roots[start] = seg.upd(0, 0, m - 1, attraction_cnt_id[start], attraction_cnt[start]);
        for (int i = start + 1; i < n; i++) {
            roots[i] = seg.upd(roots[i - 1], 0, m - 1, attraction_cnt_id[i], attraction_cnt[i]);
        }

        dc(start, n - 1, 0, start);
        for (int i = start; i < n; i++) {
            ans = max(ans, dp[i]);
        }

        // Mirror the array AND the start index together, so that the next
        // phase sees a consistent (array, start) pair.
        reverse(ALL(attraction_cnt));
        reverse(ALL(attraction_cnt_id));
        start = n - 1 - start;
    }
    return ans;
}