#include "holiday.h"
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
// ll opss[20], ops = 0;
// Aggregate stored in each segment-tree node: a count (cnt) and a running
// total (sum). Members start at zero so a default-constructed node is empty.
struct node {
    int cnt = 0;
    ll sum = 0;
    // Set this node to the combination of two child aggregates; the children
    // are only read, so const (and temporary) nodes can be merged too.
    void merge(const node &l, const node &r) {
        cnt = l.cnt + r.cnt;
        sum = l.sum + r.sum;
    }
};
// Segment tree over a multiset of values drawn from `comp`, a list sorted in
// descending order. Leaf j holds how many copies of comp[j] are present (cnt)
// and their total (sum); because leaves run from largest to smallest value,
// walk(k) returns the sum of the k largest elements in the multiset.
struct segtree {
    int n;             // number of leaves (== comp.size())
    vector<int> comp;  // leaf values, descending
    vector<node> a;    // heap layout: root at 1, children of v at 2v and 2v+1

    // n: number of leaf values; i: those values in descending order.
    segtree(int n, const vector<int> &i) : n(n), comp(i), a(4 * n) {
        assert(n == int(i.size()));
        assert(is_sorted(i.rbegin(), i.rend()));
    }

    // Add dx to cnt and dy to sum at leaf i (subtree v covers [l, r]), then
    // recompute the aggregates on the way back up.
    void upd(int v, int l, int r, int i, int dx, int dy) {
        if (l == r) {
            assert(l == i);
            a[v].cnt += dx;
            a[v].sum += dy;
        } else {
            int m = (l + r) / 2;
            if (i <= m) upd(2 * v, l, m, i, dx, dy);
            else upd(2 * v + 1, m + 1, r, i, dx, dy);
            a[v].merge(a[2 * v], a[2 * v + 1]);
        }
    }

    // Value-keyed point update; x must be one of the values in comp.
    void upd(int x, int dx, int dy) {
        // comp is descending: search its reversed (ascending) view, then map
        // the offset from the back into a forward leaf index.
        int i = n - 1 - int(lower_bound(comp.rbegin(), comp.rend(), x) - comp.rbegin());
        assert(0 <= i && i < n && comp[i] == x);
        upd(1, 0, n - 1, i, dx, dy);
    }

    // Insert one copy of value x.
    void add(int x) {
        upd(x, 1, x);
    }

    // Remove one copy of value x.
    void rem(int x) {
        upd(x, -1, -x);
    }

    // Sum of the k largest elements inside subtree v (covering [l, r]).
    // Callers keep k < a[v].cnt, so the descent always ends in a leaf that
    // can supply the remaining k copies.
    ll walk(int v, int l, int r, int k) {
        if (l == r) {
            assert(a[v].cnt >= k);
            return ll(k) * comp[l];
        } else {
            int m = (l + r) / 2;
            if (k < a[2 * v].cnt) return walk(2 * v, l, m, k);
            else return a[2 * v].sum + walk(2 * v + 1, m + 1, r, k - a[2 * v].cnt);
        }
    }

    // Sum of the k largest elements in the whole multiset; k <= 0 gives 0 and
    // k >= size gives the total. An empty tree (n == 0) has no root node, so
    // it is answered without touching a[].
    ll walk(int k) {
        if (k <= 0 || n == 0) return 0;
        if (a[1].cnt <= k) return a[1].sum;
        return walk(1, 0, n - 1, k);
    }
};
// 60
// 4436
// 334588796671
// 3389595012736
// Maximum total attraction collectable in `d` days starting from city `start`
// on a line of `n` cities whose values are a[0..n-1].
//
// Every itinerary spans a contiguous range [m, i] with m <= start <= i.
// Walking costs one day per step and is tried in both orders:
//   mode 1: right to i, then back left to m -> 2*(i - start) + (start - m) days
//   mode 2: left to m, then right to i      -> (i - start) + 2*(start - m) days
// The remaining days take the largest values in a[m..i], answered by the
// segment tree's walk(k) over a sliding window of cities.
ll findMaxAttraction(int n, int start, int d, int a[]) {
    // No cities: nothing to collect (and a zero-leaf tree must not be queried).
    if (n <= 0) return 0;
    assert(0 <= start && start < n);

    // Distinct attraction values, largest first: the tree's leaf order.
    vector<int> comp(a, a + n);
    sort(comp.begin(), comp.end(), greater<int>());
    comp.erase(unique(comp.begin(), comp.end()), comp.end());
    segtree tree(int(comp.size()), comp);

    // Between calls to shift(), the tree holds exactly a[wl..wr].
    int wl = 0, wr = -1;
    auto shift = [&](int ql, int qr) {
        // Grow before shrinking so no value's count ever drops below zero.
        while (wl > ql) tree.add(a[--wl]);
        while (wr < qr) tree.add(a[++wr]);
        while (wl < ql) tree.rem(a[wl++]);
        while (wr > qr) tree.rem(a[wr--]);
    };

    int mode = 1;
    // Divide & conquer over the left end m in [lo, hi]; the best right end for
    // m is searched in [ql, qr]. The halves are bounded by that best index,
    // which assumes the best right end is non-decreasing in m.
    auto dnc = [&](int lo, int hi, int ql, int qr, auto &&self) -> ll {
        int m = (lo + hi) / 2;
        assert(m <= start && start <= ql);
        pair<ll, int> best{-1, -1};  // (value, right end); ties keep the larger i
        for (int i = ql; i <= qr; i++) {
            shift(m, i);
            int days = (3 - mode) * (i - start) + mode * (start - m);
            best = max(best, {tree.walk(d - days), i});
        }
        ll res = best.first;
        // Column m is finished; recurse only on the columns on either side.
        if (lo < m) res = max(res, self(lo, m - 1, ql, best.second, self));
        if (m < hi) res = max(res, self(m + 1, hi, best.second, qr, self));
        return res;
    };

    ll ans = dnc(0, start, start, n - 1, dnc);
    mode = 2;
    ans = max(ans, dnc(0, start, start, n - 1, dnc));
    return ans;
}
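/*
 * Grader results tables captured from the judge page (every table still
 * showed its "Fetching results..." placeholder when captured):
 *
 * | # | Verdict | Execution time | Memory | Grader output |
 * |---|---|---|---|---|
 * | Fetching results... |
 * | # | Verdict | Execution time | Memory | Grader output |
 * |---|---|---|---|---|
 * | Fetching results... |
 * | # | Verdict | Execution time | Memory | Grader output |
 * |---|---|---|---|---|
 * | Fetching results... |
 * | # | Verdict | Execution time | Memory | Grader output |
 * |---|---|---|---|---|
 * | Fetching results... |
 */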