Submission #905106

# | Time | Username | Problem | Language | Result | Execution time | Memory
905106 | GrindMachine | Cake 3 (JOI19_cake3) | C++17
100 / 100
1078 ms | 128980 KiB
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>

using namespace std;
using namespace __gnu_pbds;

// Ordered set with order statistics (find_by_order / order_of_key), built on the
// GNU policy-based red-black tree.
template <typename T>
using Tree = __gnu_pbds::tree<T, __gnu_pbds::null_type, std::less<T>, __gnu_pbds::rb_tree_tag,
                              __gnu_pbds::tree_order_statistics_node_update>;

// Short names for the scalar and pair types used throughout the solution.
using ll = long long int;
using ld = long double;
using pii = std::pair<int, int>;
using pll = std::pair<ll, ll>;

// Competitive-programming shorthands.
// Every macro argument is parenthesized, so arguments that are themselves
// expressions (`*p`, `a << 1`, `cond ? x : y`) expand with the intended precedence.
// `fastio` is a single expression, so it is safe inside an unbraced `if`.
#define fastio (ios_base::sync_with_stdio(false), cin.tie(nullptr))
#define pb push_back
// Plain newline instead of std::endl: no flush after every line of output.
#define endl '\n'
#define sz(a) ((int)(a).size())
#define setbits(x) __builtin_popcountll(x)
#define ff first
#define ss second
#define conts continue
#define ceil2(x,y) (((x)+(y)-1)/(y))
#define all(a) (a).begin(), (a).end()
#define rall(a) (a).rbegin(), (a).rend()
#define yes cout << "Yes" << endl
#define no cout << "No" << endl

#define rep(i,n) for(int i = 0; i < (n); ++i)
#define rep1(i,n) for(int i = 1; i <= (n); ++i)
#define rev(i,s,e) for(int i = (s); i >= (e); --i)
#define trav(i,a) for(auto &i : (a))

// Lowers `a` to `b` when `b` compares strictly smaller; otherwise `a` is left as-is.
template <class T>
void amin(T &a, T b) {
    if (b < a) a = b;
}

// Raises `a` to `b` when `a` compares strictly smaller than `b`; otherwise `a` is left as-is.
template <class T>
void amax(T &a, T b) {
    if (a < b) a = b;
}

// When LOCAL is defined, pull in a local debugging helper header.
// Otherwise debug(x) expands to the constant 42, so any debug call is a no-op expression.
#ifdef LOCAL
#include "debug.h"
#else
#define debug(x) 42
#endif

/*



*/

// Problem-wide constants; exact integer spellings of 1e9+7, 2e5+5, int(1e9)+5, ll(1e18)+5.
constexpr int MOD = 1'000'000'007;
constexpr int N = 200'005;
constexpr int inf1 = 1'000'000'005;
constexpr ll inf2 = 1'000'000'000'000'000'005LL;

// Original value of each element, indexed by position (filled in solve() after sorting).
// wavelet_tree::build accumulates these into its per-node prefix sums.
vector<int> oa(N);

struct wavelet_tree{
    struct node{
        node *l, *r;
        vector<int> b;
        vector<ll> pref;
    };

    node* root;
    int siz;

    wavelet_tree(){

    }

    node* build(vector<int> &a, vector<int> inds, int l, int r){
        if(l > r) return NULL;
        int mid = (l+r) >> 1;
        vector<int> b(sz(inds)+1);

        vector<int> la,ra;
        rep(i,sz(inds)){
            int x = a[inds[i]];
            b[i+1] = b[i]+(x <= mid);
            if(x <= mid){
                la.pb(inds[i]);
            }
            else{
                ra.pb(inds[i]);
            }
        }

        node* curr = new node();
        curr->b = b;

        vector<ll> pref;
        pref.pb(0);

        trav(i,inds){
            pref.pb(pref.back()+oa[i]);
        }

        curr->pref = pref;

        if(l != r){         
            curr->l = build(a,la,l,mid);
            curr->r = build(a,ra,mid+1,r);
        }

        return curr;
    }

    void build(vector<int> &a, int mx_siz){
        siz = mx_siz;
        vector<int> inds;
        rep(i,sz(a)) inds.pb(i);
        root = build(a,inds,0,siz);
    }

    int kth(node* u, int lx, int rx, int l, int r, int k){
        if(lx == rx){
            return lx;
        }

        int mid = (lx+rx) >> 1;
        auto &b = u->b;
        int cnt = b[r]-b[l-1];

        if(k <= cnt){
            // go left
            return kth(u->l,lx,mid,b[l-1]+1,b[r],k);
        }
        else{
            // go right
            return kth(u->r,mid+1,rx,l-b[l-1],r-b[r],k-cnt);
        }
    }

    int kth(int l, int r, int k){
        return kth(root,0,siz,l+1,r+1,k);
    }

    pll get(node* u, int lx, int rx, int vl, int vr, int l, int r){
        if(lx > vr or rx < vl) return {0,0};
        if(!u) return {0,0};
        if(lx > rx) return {0,0};
        if(lx >= vl and rx <= vr){
            return {r-l+1,u->pref[r]-u->pref[l-1]};
        }

        int mid = (lx+rx) >> 1;
        auto &b = u->b;

        pll p1 = get(u->l,lx,mid,vl,vr,b[l-1]+1,b[r]);
        pll p2 = get(u->r,mid+1,rx,vl,vr,l-b[l-1],r-b[r]);
        return {p1.ff+p2.ff,p1.ss+p2.ss};
    }

    pll get(int l, int r, int k){
        return get(root,0,siz,k,siz,l+1,r+1);
    }
};

// Solver state shared by solve(), get_cost() and go().
int n = 0, m = 0;      // element count; m = how many values get_cost sums per window
ll ans = -inf2;        // best get_cost value recorded by go()
vector<pii> a(N);      // per element: {second input field, first input field}; sorted in solve()
vector<int> c;         // sorted distinct values (coordinate-compression table)
wavelet_tree wt;       // built in solve() over the compressed values

// Score of window [l, r]: the sum of the m largest original values among positions l..r,
// minus 2 * (a[r].first - a[l].first). Returns -inf2 when the window holds fewer than m elements.
ll get_cost(int l, int r){
    int len = r-l+1;
    if(len < m) return -inf2;

    // The m-th largest element is the (len-m+1)-th smallest.
    int k = len-m+1;
    int kth = wt.kth(l,r,k);
    // All elements whose value index is >= kth; ties at kth can make cnt exceed m.
    auto [cnt,sum] = wt.get(l,r,kth);
    assert(cnt >= m);
    ll extra = cnt-m;
    sum -= extra*c[kth];   // drop the surplus copies of the threshold value
    // Widen to 64-bit before doubling so a large key difference cannot overflow int.
    sum -= 2LL*((ll)a[r].ff-a[l].ff);

    return sum;
}

void go(int l, int r, int optl, int optr){
    if(l > r) return;
    int mid = (l+r) >> 1;
    ll best = -inf2;
    int opt_mid = optr;

    for(int j = optl; j <= optr; ++j){
        if(mid <= j){
            ll cost = get_cost(mid,j);
            if(cost > best){
                best = cost;
                opt_mid = j;
            }
        }
    }

    amax(ans,best);

    go(l,mid-1,optl,opt_mid);
    go(mid+1,r,opt_mid,optr);
}

void solve(int test_case)
{
    cin >> n >> m;
    rep(i,n) cin >> a[i].ss >> a[i].ff;
    sort(a.begin(),a.begin()+n);

    vector<int> b(n);
    rep(i,n) oa[i] = b[i] = a[i].ss;

    c = b;
    sort(all(c));
    c.resize(unique(all(c))-c.begin());
    int siz = sz(c);

    rep(i,n) b[i] = lower_bound(all(c),b[i])-c.begin();
    rep(i,n) a[i].ss = b[i];
    wt.build(b,n);

    go(0,n-1,0,n-1);
    cout << ans << endl;
}

int main()
{
    fastio;

    int t = 1;
    // cin >> t;

    rep1(i, t) {
        solve(i);
    }

    return 0;
}

Compilation message (stderr)

cake3.cpp: In function 'void solve(int)':
cake3.cpp:219:9: warning: unused variable 'siz' [-Wunused-variable]
  219 |     int siz = sz(c);
      |         ^~~
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...