This submission was migrated from the previous version of oj.uz, which used a different machine for grading. It may get a different result if resubmitted.
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
using pii = pair<int, int>;
using pli = pair<ll, int>;
#define pb push_back
#define ff first
#define ss second
// Adds two (sum, count) pairs component by component.
pli operator + (pli x, pli y){
    pli total = x;
    total.ff += y.ff;
    total.ss += y.ss;
    return total;
}
// Subtracts two (sum, count) pairs component by component.
pli operator - (pli x, pli y){
    x.ff -= y.ff;
    x.ss -= y.ss;
    return x;
}
// Persistent segment tree over the distinct values of a 1-indexed array.
// Version k (root[k]) holds the multiset of the first k values passed to
// upd(int). Each leaf stores (sum, count) for one distinct value, so the
// difference of two versions describes a contiguous block of insertions.
// get() uses that difference to return the sum of the k smallest values.
//
// Nodes live in a pool owned through a shared_ptr rather than being leaked via
// raw `new`. Copies of a PST share the pool, so their root pointers stay valid
// for as long as any copy is alive.
struct PST{
    struct node{
        node *l, *r;
        long long s; int c;
        node(){
            l = r = nullptr;
            s = c = 0;
        }
        node(long long x, int y){
            l = r = nullptr;
            s = x;
            c = y;
        }
        node(node *ls, node *rs){
            l = ls; r = rs;
            s = l -> s + r -> s;
            c = l -> c + r -> c;
        }
    };
    // Owns every node ever created. std::deque::emplace_back never relocates
    // existing elements, so node pointers stay valid while the pool grows.
    std::shared_ptr<std::deque<node>> pool = std::make_shared<std::deque<node>>();
    // Constructs a node inside the pool and returns a stable pointer to it.
    template <class... Args>
    node *make(Args&&... args){
        pool -> emplace_back(std::forward<Args>(args)...);
        return &pool -> back();
    }
    std::vector<node*> root;
    // Builds an all-zero tree over ranks [tl, tr]. The `>=` test (rather than
    // `==`) stops an empty rank range (n == 0) from recursing forever.
    node *build(int tl, int tr){
        if (tl >= tr) return make();
        int tm = (tl + tr) / 2;
        return make(build(tl, tm), build(tm + 1, tr));
    }
    int n, cc;                      // n: distinct value count; cc: versions built
    std::vector<int> a;             // a[1..n]: sorted distinct values; a[0] == 0
    std::vector<int> :: iterator it;
    // ns: number of values that will be inserted with upd(int).
    // as: 1-indexed values (as[0] is ignored); must be non-empty.
    PST(int ns, std::vector<int> as){
        root.resize(ns + 1);
        std::sort(as.begin() + 1, as.end());
        // a = {0, distinct values of as[1..] in ascending order}
        a.assign(1, 0);
        a.insert(a.end(), as.begin() + 1, std::unique(as.begin() + 1, as.end()));
        n = (int) a.size() - 1;
        root[0] = build(1, n);
        cc = 0;
    }
    // Returns a new version of subtree v with the value x added at rank p.
    node* upd(node *v, int tl, int tr, int& p, int& x){
        if (tl == tr) return make(v -> s + x, v -> c + 1);
        int tm = (tl + tr) / 2;
        if (p <= tm){
            return make(upd(v -> l, tl, tm, p, x), v -> r);
        }
        return make(v -> l, upd(v -> r, tm + 1, tr, p, x));
    }
    // Appends version cc + 1, which equals version cc plus the value x.
    // x must be one of the values given to the constructor.
    void upd(int x){
        cc++;
        it = std::lower_bound(a.begin() + 1, a.end(), x);
        int i = (int) (it - a.begin());
        root[cc] = upd(root[cc - 1], 1, n, i, x);
    }
    // (sum, count) of the values whose rank lies in [l, r] within subtree v.
    std::pair<long long, int> sum(node *v, int tl, int tr, int& l, int& r){
        if (l > tr || r < tl) return {0, 0};
        if (l <= tl && tr <= r) return {v -> s, v -> c};
        int tm = (tl + tr) / 2;
        std::pair<long long, int> lo = sum(v -> l, tl, tm, l, r);
        std::pair<long long, int> hi = sum(v -> r, tm + 1, tr, l, r);
        return {lo.first + hi.first, lo.second + hi.second};
    }
    // (sum, count) of the values with rank in [l, r] in version k.
    std::pair<long long, int> sum(int k, int l, int r){
        return sum(root[k], 1, n, l, r);
    }
    // Sum of the k smallest values of the multiset (v2 minus v1). Intended
    // for 0 <= k <= that multiset's size; a negative k just walks to the
    // leftmost leaf and returns a[leaf] * k.
    long long find(node *v1, node *v2, int tl, int tr, int k){
        if (tl == tr) return 1LL * a[tl] * k;
        int leftCnt = (v2 -> l -> c) - (v1 -> l -> c), tm = (tl + tr) / 2;
        if (k <= leftCnt){
            return find(v1 -> l, v2 -> l, tl, tm, k);
        }
        long long leftSum = (v2 -> l -> s) - (v1 -> l -> s);
        return leftSum + find(v1 -> r, v2 -> r, tm + 1, tr, k - leftCnt);
    }
    // Sum of the min(k, r - l + 1) smallest values among those inserted by
    // the l-th through r-th calls of upd(int) (1-indexed, inclusive).
    long long get(int l, int r, int k){
        k = std::min(k, r - l + 1);
        return find(root[l - 1], root[r], 1, n, k);
    }
};
// Best total attraction over tours that first walk left from city x to a city
// l (1 <= l <= x) and then walk right to a city r (x <= r <= n). Walking costs
// (x - l) + (r - l) days. Each remaining day visits one city of [l, r], so the
// tour collects the largest min(days left, r - l + 1) attractions of that range.
//
// a: 1-indexed NEGATED attractions (a[0] unused), so PST's "sum of the k
//    smallest" query is minus the sum of the k largest attractions.
// n: number of cities; x: 1-indexed start city; d: number of days.
// Returns the best value found (never below 0).
//
// For each left end l, the best right end is found with divide-and-conquer
// optimisation, which relies on the best r being non-decreasing in l.
ll get(const vector<int>& a, int n, int x, int d){
    PST T(n, a);
    for (int i = 1; i <= n; i++){
        T.upd(a[i]);
    }
    // Value of the tour with ends (l, r). A negative day budget marks an
    // infeasible pair; with non-negative attractions it yields a value <= 0.
    auto f = [&](int l, int r){
        int daysLeft = d - (x - l) - (r - l);
        return -T.get(l, r, daysLeft);
    };
    ll out = 0;
    function<void(int, int, int, int)> solve = [&](int l, int r, int l1, int r1){
        if (l > r) return;
        int m = (l + r) / 2;
        // The split point defaults to l1, not 0. If no r in [l1, r1] is
        // feasible for m, any split inside the range is valid. A default of 0
        // would push later searches outside [x, n] and onto inverted ranges.
        pli mx = {-1, -l1};
        for (int i = l1; i <= r1; i++){
            mx = max(mx, {f(m, i), -i});
        }
        out = max(out, mx.ff);
        int opt = -mx.ss;
        solve(l, m - 1, l1, opt);
        solve(m + 1, r, opt, r1);
    };
    solve(1, x, x, n);
    return out;
}
// Entry point. A[0..n-1] are the attractions, x is the 0-indexed start city
// and d the number of days. Tries both "left first, then right" (original
// order) and "right first, then left" (mirrored array) and returns the better.
ll findMaxAttraction(int n, int x, int d, int A[]){
    // Slot 0 is a dummy so cities are 1-indexed. Values are negated so that a
    // "k smallest" query equals minus the sum of the k largest attractions.
    vector<int> neg(n + 1, 0);
    for (int city = 0; city < n; ++city){
        neg[city + 1] = -A[city];
    }
    const int start = x + 1;
    ll best = get(neg, n, start, d);
    // Mirror the cities so the same routine covers "right first" tours.
    reverse(neg.begin() + 1, neg.end());
    return max(best, get(neg, n, n + 1 - start, d));
}
Compilation message (stderr)
holiday.cpp: In constructor 'PST::PST(int, std::vector<int>)':
holiday.cpp:51:18: warning: comparison of integer expressions of different signedness: 'int' and 'std::vector<int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
51 | while (i < as.size()){
| ~~^~~~~~~~~~~
holiday.cpp:53:22: warning: comparison of integer expressions of different signedness: 'int' and 'std::vector<int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
53 | while (j < as.size() && as[i] == as[j]){
| ~~^~~~~~~~~~~
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |