This submission was migrated from a previous version of oj.uz, which used a different machine for grading. This submission may produce a different result if resubmitted.
#include <bits/stdc++.h>
#define ll long long
#define pb push_back
#define mp make_pair
#define P pair<ll,ll>
using namespace std;
// Capacity of the fixed-size global arrays (val, pref, arr); must exceed n,
// since pref is indexed up to n.
constexpr int maxn = 200100;
// n: number of items read by main; m: minimum number of items a window
// j..i must contain (see the loop bounds in main).
ll n;
ll m;
ll arr[maxn];  // NOTE(review): not referenced anywhere in the visible code
// Items stored as (b, a) pairs exactly as read in main, then sorted.
vector<P> v;
// Segment-tree node: cnt and sum aggregate the elements stored in the
// node's range. Internal nodes only point at their children, so several
// trees may share the same subtrees (insert_number relies on this).
struct node {
    node *lc, *rc;  // children; both null for a leaf
    ll cnt;
    ll sum;

    // Leaf: no children, explicit count and sum.
    node(ll leaf_cnt, ll leaf_sum)
        : lc(nullptr), rc(nullptr), cnt(leaf_cnt), sum(leaf_sum) {}

    // Internal node: links (does not copy) the two children and caches
    // their combined count and sum.
    node(node *left, node *right)
        : lc(left), rc(right),
          cnt(left->cnt + right->cnt), sum(left->sum + right->sum) {}
};
ll val[maxn];
set<ll>st;
map<ll,int>ind;
node* pref[maxn];
// Builds an empty tree over the leaf range [li, ri] (default [0, n]),
// splitting at mid = (li + ri) / 2. Every node starts with cnt = 0 and
// sum = 0. Returns the root.
node* build(int li=0, int ri=n) {
    if (li != ri) {
        const int mid = (li + ri) / 2;
        node *left = build(li, mid);
        node *right = build(mid + 1, ri);
        return new node(left, right);
    }
    return new node(0LL, 0LL);
}
// Returns the root of a new tree version equal to `curr` plus one element
// at leaf `pos`: along the root-to-leaf path every node is re-created with
// updated cnt/sum, while all untouched subtrees are shared with `curr`.
// `curr` itself is never modified.
//
// pos     : leaf index in [li, ri] (a compressed rank in this program).
// sum_val : amount added to the leaf's sum.
// li, ri  : leaf range covered by `curr` (default [0, n]).
//
// sum_val is long long (same type as the file's `ll`); it used to be `int`,
// which silently truncated the ll values passed in by main before they were
// added to the ll sums.
node* insert_number(node *curr, int pos, long long sum_val, int li=0, int ri=n) {
    if(li == ri) {
        return new node(curr->cnt+1, curr->sum+sum_val);
    }
    else {
        int mid = (li + ri) / 2;
        // Copy only the child on the path; reuse the other one as-is.
        if(pos <= mid)
            return new node(insert_number(curr->lc, pos, sum_val, li, mid), curr->rc);
        else
            return new node(curr->lc, insert_number(curr->rc, pos, sum_val, mid+1, ri));
    }
}
// Returns the sum of the k largest values among the elements that are in
// version `r` but not in version `l` (counts/sums are taken as r minus l
// node by node), over the leaf range [li, ri] (default [0, n]).
// Leaves are compressed ranks assigned in ascending value order by main, so
// the right subtree holds the larger values and is taken from first.
// Precondition (not checked; main guarantees it): 0 <= k <= number of
// elements in r minus l.
ll solve(node *l, node *r, int k, int li=0, int ri=n) {
    ll taken = 0;  // sums of right subtrees already taken whole
    while (k != 0) {
        const ll here_cnt = r->cnt - l->cnt;
        if (here_cnt == k) return taken + (r->sum - l->sum);
        if (li == ri) return taken + k * val[li];  // leaf: k copies of one value
        const int mid = (li + ri) / 2;
        const ll right_cnt = r->rc->cnt - l->rc->cnt;
        if (right_cnt >= k) {
            // All k picks fit in the right (larger) half.
            l = l->rc;
            r = r->rc;
            li = mid + 1;
        } else {
            // Take the whole right half, then the rest from the left half.
            taken += r->rc->sum - l->rc->sum;
            k -= right_cnt;
            l = l->lc;
            r = r->lc;
            ri = mid;
        }
    }
    return taken;
}
int main() {
cin>>n>>m;
ll a, b;
for(int i=0;i<n;i++) {
cin>>a>>b;
v.pb(mp(b, a));
st.insert(a);
}
sort(v.begin(), v.end());
int br = 0;
for(ll i:st) {
val[br] = i;
ind[i] = br++;
}
node *root = build();
for(int i=1;i<=n;i++) {
root = insert_number(root, ind[v[i-1].second], v[i-1].second);
pref[i] = root;
}
ll result = LLONG_MIN;
for(int i=1;i<=n;i++) {
if(i < m) continue;
//cout<<i<<": \n";
for(int j=i-1;j>=1;j--) {
if(i-j+1<m) continue;
ll sum = v[i-1].second + v[j-1].second;
ll cost = 2*(v[i-1].first - v[j-1].first);
//cout<<j<<": "<<sum<<" - "<<cost<<" with maximum elements "<<solve(pref[j], pref[i-1], m-2)<<"\n";
result = max(result, solve(pref[j], pref[i-1], m-2) + sum - cost);
}
}
cout<<result<<"\n";
return 0;
}
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |