Submission #112557

# | Time | Username | Problem | Language | Result | Execution time | Memory
112557 | | someone_aa | Cake 3 (JOI19_cake3) | C++17 | 100 / 100 | 2044 ms | 226908 KiB
#include <bits/stdc++.h>
#define ll long long
#define pb push_back
#define mp make_pair
#define P pair<ll,ll>
using namespace std;
// Capacity of the static arrays (val, pref, arr). pref is indexed up to n, so this
// is assumed to exceed the number of pieces — NOTE(review): confirm against the limits.
const int maxn = 200100;
// n: number of (a, b) pairs read in main.
// m: selection size used by f/fsolve (a window [i, j] must hold at least m pieces,
//    and f sums the m-2 best interior values). arr is not referenced in this file.
ll n, m, arr[maxn];
// Pieces stored as (b, a) pairs; main sorts them, so they are ordered primarily by b.
vector<P>v;

// Node of a persistent segment tree.
// cnt = number of inserted elements in this node's range, sum = their total value.
// Nodes are never mutated after construction; new versions share old subtrees.
struct node {
    node *lc = nullptr;  // left child; stays null for nodes built from (cnt, sum)
    node *rc = nullptr;  // right child; stays null for nodes built from (cnt, sum)
    ll cnt;
    ll sum;

    // Leaf node holding the given count and sum.
    node(ll _cnt, ll _sum) : cnt(_cnt), sum(_sum) {}

    // Internal node over two existing children (shared, not copied);
    // its count and sum are the children's totals.
    node(node *a, node *b)
        : lc(a), rc(b), cnt(a->cnt + b->cnt), sum(a->sum + b->sum) {}
};

// val[k]: the k-th smallest distinct `a` value (filled in main from st).
// result: best objective value found so far (maximized in fsolve, printed by main).
ll val[maxn], result;
// Distinct `a` values; iterated in ascending order to assign ranks.
set<ll>st;
// Maps an `a` value to its rank k, so that val[ind[x]] == x.
map<ll,int>ind;

// pref[i] (i >= 1): persistent segment-tree version after inserting the first i
// pieces of the sorted v; f queries the difference of two such versions.
node* pref[maxn];


// Builds an all-zero segment tree over the leaf range [li, ri].
// Every internal node gets both children; leaves have cnt = sum = 0.
node* build(int li=0, int ri=n) {
    if(li != ri) {
        const int half = (li + ri) / 2;
        node *left_part = build(li, half);
        node *right_part = build(half + 1, ri);
        return new node(left_part, right_part);
    }
    return new node(0LL, 0LL);
}

// Returns a new version of the persistent tree `curr` with one more element at
// leaf `pos` (a value rank in [li, ri]) contributing `sum_val` to the sums.
// Only the root-to-leaf path is copied; all other subtrees are shared with `curr`.
// sum_val is `ll` because main passes `ll` values, and an `int` parameter would
// silently truncate anything outside the int range.
node* insert_number(node *curr, int pos, ll sum_val, int li=0, int ri=n) {
    if(li == ri)
        return new node(curr->cnt + 1, curr->sum + sum_val);

    const int mid = (li + ri) / 2;
    if(pos <= mid)
        return new node(insert_number(curr->lc, pos, sum_val, li, mid), curr->rc);
    return new node(curr->lc, insert_number(curr->rc, pos, sum_val, mid+1, ri));
}

// Sum of the k largest values among the elements counted in version `r` but not
// in version `l` (r is assumed to be a later version built on top of l), over
// the value-rank range [li, ri]. Higher ranks are in the right child and map to
// larger values via val[], so the walk takes from the right first.
// k <= 0 yields 0. If k is at least the number of available elements, all of them
// are summed. Previously such a request descended to a leaf and over-counted there.
ll solve(node *l, node *r, int k, int li=0, int ri=n) {
    if(k <= 0) return 0LL;
    const ll total_cnt = r->cnt - l->cnt;
    const ll total_sum = r->sum - l->sum;

    // Everything in this range is taken (also clamps oversized requests).
    if(k >= total_cnt) return total_sum;
    // Leaf: every element here has the same value val[li].
    if(li == ri) return k * val[li];

    const int mid = (li + ri) / 2;
    const ll right_cnt = r->rc->cnt - l->rc->cnt;
    if(right_cnt >= k) return solve(l->rc, r->rc, k, mid+1, ri);

    // Take the whole right child, then the rest from the left child.
    const ll right_sum = r->rc->sum - l->rc->sum;
    return right_sum + solve(l->lc, r->lc, k - right_cnt, li, mid);
}

// Objective for choosing piece i as the left endpoint and piece j as the right
// endpoint (1-based positions in the sorted v; requires i < j). The result is
// their two `second` values, plus the m-2 largest `second` values strictly
// between them (difference of versions pref[j-1] and pref[i]), minus twice the
// `first` gap. Returns LLONG_MIN when i >= j or when [i, j] holds fewer than m pieces.
ll f(int i, int j) {
    if(i >= j || j - i + 1 < m) return LLONG_MIN;
    const P &lhs = v[i-1];
    const P &rhs = v[j-1];
    const ll endpoints = lhs.second + rhs.second;
    const ll gap_cost = 2*(rhs.first - lhs.first);
    return solve(pref[i], pref[j-1], m-2) + endpoints - gap_cost;
}

// Divide and conquer over left endpoints in [l, r]. For the middle left endpoint,
// scan right endpoints in [max(mid, optl), optr] and fold the best f value into
// `result`. Then recurse, splitting the search window at the first arg-max.
// This relies on the best right endpoint being non-decreasing in the left
// endpoint (assumed here, not checked).
void fsolve(int l, int r, int optl, int optr) {
    if(l > r) return;
    const int left_end = (l + r) / 2;

    ll best = LLONG_MIN;
    ll best_at = -1;

    if(n - left_end + 1 >= m) {
        for(int right_end = max(left_end, optl); right_end <= optr; ++right_end) {
            const ll cand = f(left_end, right_end);
            if(cand > best) {  // strict: keeps the first (leftmost) maximum
                best = cand;
                best_at = right_end;
            }
        }
    } else {
        // Too few pieces remain for any valid window; use n as the split point.
        best_at = n;
    }

    result = max(result, best);

    fsolve(l, left_end - 1, optl, best_at);
    fsolve(left_end + 1, r, best_at, optr);
}

// Reads n and m followed by n pairs (a, b). Builds one persistent segment-tree
// version per prefix of the pieces sorted by b, keyed by the rank of a.
// Then runs the divide-and-conquer search and prints the best objective.
int main() {
    // Up to ~2e5 input lines: decouple iostreams from C stdio and untie cin.
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);

    cin>>n>>m;
    for(int i=0;i<n;i++) {
        ll a, b;
        cin>>a>>b;
        v.pb(mp(b, a));  // stored as (b, a) so the sort below orders by b
        st.insert(a);
    }
    sort(v.begin(), v.end());

    // Assign ascending ranks to the distinct a values.
    int br = 0;
    for(ll x : st) {
        val[br] = x;
        ind[x] = br++;
    }

    node *root = build();
    pref[0] = root;  // empty version, so every pref[0..n] entry is valid
    for(int i=1;i<=n;i++) {
        root = insert_number(root, ind[v[i-1].second], v[i-1].second);
        pref[i] = root;
    }

    result = LLONG_MIN;
    fsolve(1, n, 1, n);

    cout<<result<<"\n";
    return 0;
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...