This submission was migrated from a previous version of oj.uz, which used a different machine for grading. It may produce a different result if resubmitted.
#include "holiday.h"
#include <bits/stdc++.h>
// Competitive-programming shortcut: from here on every `int` token in this
// file is textually replaced by `ll`, so all "int" variables are 64-bit.
// (`int32_t` is unaffected because it is a distinct identifier.)
#define int ll
using namespace std;
typedef long long ll;
// Capacity of every per-index array below, and the exclusive upper bound of
// the Fenwick-tree positions walked by Add/Get.
const int N = 2e5 + 10;
// Number of bit levels tried by the binary descent in Get (levels whose step
// would reach N or beyond are skipped there).
const int Log = 20;
// Used negated as the "no candidate yet" value for the running maximum in Solve.
const ll Inf = 1e18;
// n, st, d : copies of the arguments of findMaxAttraction.
// a[]      : working copy of the input values (reversed for the second pass).
// I[]      : index permutation sorted by descending a[].
// ord[i]   : 1-based rank of index i in that ordering; the Fenwick position of i.
int n, st, d, a[N], I[N], ord[N];
// Fenwick trees indexed by rank: sum[] accumulates a-values, cnt[] accumulates
// how many indices are currently inserted.
ll sum[N], cnt[N];
// Best value found so far; Solve only ever raises it.
ll ans = 0;
// Fenwick point update at the rank of index `id`.
// z = +1 inserts the index (count +1, value +a[id]); z = -1 removes it again.
void Add(int id, int z){
    const auto delta = a[id] * z;   // value contribution, same for every node touched
    int node = ord[id];             // 1-based rank of id is its tree position
    while(node < N){
        cnt[node] += z;
        sum[node] += delta;
        node += node & (-node);     // jump to the next node covering this position
    }
}
// Binary descent over the Fenwick trees: walks to the longest rank prefix whose
// inserted-count does not exceed X and returns the a-value total of that prefix.
// Because ranks are ordered by descending a[], this is the total of the (up to)
// X largest values currently inserted.
ll Get(int X){
    int node = 0;
    ll total = 0;
    for(int bit = Log - 1; bit >= 0; --bit){
        const int next = node + (1 << bit);
        // Take this step only if it stays inside the tree and the block it adds
        // still fits into the remaining budget X.
        if(next < N && cnt[next] <= X){
            node = next;
            X -= cnt[node];
            total += sum[node];
        }
    }
    return total;
}
// Divide-and-conquer search: for every left index `mid` in [L, R) it tries the
// right boundaries i in [l, r], scores each with Get, and folds the best score
// into the global `ans`. The best boundary `opt` found for the middle `mid`
// then splits the boundary range for the two recursive halves.
//
// Ln/Rn describe the half-open index window [Ln, Rn) that is currently
// inserted into the Fenwick trees via Add. They start as L and l, i.e. the
// code treats [L, l) as the window inserted on entry; before each recursive
// call the window is moved to that call's own [L, l).
void Solve(int L, int R, int l, int r){
if(L >= R) return ;
int mid = (L + R) >> 1;
// rem: what is left of d after subtracting the distance st - mid, clamped at 0
// (presumably the budget remaining once the walk from st has reached mid).
int rm2, rem = max(0ll, d - (st - mid));
int Ln = L, Rn = l;
// Drop every inserted index left of mid so the window starts at mid (or is
// empty when l <= mid).
for(int i = Ln; i < min(mid, Rn); i++) Add(i, -1);
Ln = min(mid, Rn); //Rn = max(Rn, mid);
// mx/opt: best score for this mid and the boundary i that produced it.
ll val, mx = -Inf, opt = l;
for(int i = Rn; i <= r; i++){
if(i < mid){
// Boundary still left of mid: slide the (empty) window along without
// inserting anything and without scoring.
Ln ++; Rn ++;
continue;
} else {
if(i > mid){
// Window is [mid, i). Subtract the extra distance i - mid - 1 from rem and
// ask Get for the best total using at most that many inserted indices.
rm2 = max(0ll, rem - max(0ll, (i - mid - 1)) );
val = Get(rm2);
// Strict comparison keeps the smallest i among equal scores.
if(val > mx){
mx = val;
opt = i;
}
}
}
// Grow the window by one index for the next boundary; nothing to add after the
// last candidate.
if(i != r){
Add(Rn, 1);
Rn ++;
}
}
//Rn = r;
assert(Rn == r);
ans = max(ans, mx);
//cerr << "$$ " << mid << ' ' << mx << '\n';
// Move the window from [Ln, Rn) to [mid + 1, opt), the entry window of the
// right half. Order: extend left, extend right, shrink left, shrink right.
int Lrq = mid + 1, Rrq = opt;
while(Lrq < Ln){ Add(Ln - 1, 1); Ln --; }
while(Rn < Rrq){ Add(Rn, 1); Rn ++; }
while(Ln < Lrq){ Add(Ln, -1); Ln ++; }
while(Rrq < Rn){ Add(Rn - 1, -1); Rn --; }
Solve(mid + 1, R, opt, r);
// Same four-step move to [L, l), the entry window of the left half. This uses
// the local Ln/Rn unchanged, so it relies on the call above having left the
// inserted set exactly as it was handed over.
Lrq = L, Rrq = l;
while(Lrq < Ln){ Add(Ln - 1, 1); Ln --; }
while(Rn < Rrq){ Add(Rn, 1); Rn ++; }
while(Ln < Lrq){ Add(Ln, -1); Ln ++; }
while(Rrq < Rn){ Add(Rn - 1, -1); Rn --; }
Solve(L, mid, l, opt);
}
// Entry point declared in holiday.h.
//
// _n  : number of entries in _a
// _st : start index into _a
// _d  : budget that Solve spends on distance plus picked entries
//       (presumably the number of days -- confirm against the task statement)
// _a  : per-index values; copied into the file-level a[], never modified
//
// Returns the best total found by two Solve passes: one on the array as given,
// and one on the mirrored array with the start index mirrored accordingly.
ll findMaxAttraction(int32_t _n, int32_t _st, int32_t _d, int32_t _a[]) {
    n = _n; st = _st; d = _d;
    // `ans` is a file-level accumulator that Solve only ever raises. Without this
    // reset a second call would return the maximum over both calls.
    ans = 0;
    for(int i = 0; i < n; i++) a[i] = _a[i];

    // Ranks all indices by descending a[] (rank 1 = largest value) for the
    // current contents of a[], then runs the search for left indices 0..st.
    auto run_pass = [&](){
        iota(I, I + n, 0);
        sort(I, I + n, [&](int i, int j){ return a[i] > a[j]; });
        for(int i = 0; i < n; i++) ord[I[i]] = i + 1;
        Solve(0, st + 1, 0, n);
    };

    run_pass();
    //cerr << "###################\n";
    // Mirror the instance so the same search covers the opposite direction.
    reverse(a, a + n);
    st = n - 1 - st;
    run_pass();
    return ans;
}
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |