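// Submission #1297559
//
// | #       | Time | Username    | Problem              | Language | Result    | Execution time | Memory     |
// |---------|------|-------------|----------------------|----------|-----------|----------------|------------|
// | 1297559 |      | danglayloi1 | Cake 3 (JOI19_cake3) | C++20    | 100 / 100 | 1530 ms        | 199624 KiB |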
#include <bits/stdc++.h>
using namespace std;
using ll = long long;

// Pair type for the input: main reads .fi then .se, sorts by .se (cmp) and
// inserts .fi into the persistent segment tree.
using ii = pair<int, int>;

// Member shorthands: `.fi` / `.se` expand to `.first` / `.second`.
#define fi first
#define se second

// Large sentinel; -inf is returned by cost() for an infeasible window.
constexpr ll inf = 0x3f3f3f3f3f3f3f3fLL;

const ll mod = 1000000007;    // not referenced elsewhere in this file
const int nx = 200005;        // capacity of the global arrays (indices 0..nx-1)
const int base = 1000000000;  // upper end of the value range [1, base] used for the tree
// Orders pairs by their second component only, ascending. main passes it to
// std::sort to arrange a[1..n] before the persistent tree is built.
bool cmp(ii a, ii b)
{
    const int lhs = a.se;
    const int rhs = b.se;
    return rhs > lhs;
}
/// Persistent segment tree over an integer value range [l, r] supplied by the
/// caller of update()/find() (main uses [1, base]). Each version records, per
/// value, how many times it was inserted and the sum of those insertions, so
/// two versions where one is derived from the other describe the multiset of
/// values inserted in between.
///
/// Node 0 must be an all-zero "null" node (created with add()); absent
/// subtrees point at it. Nodes are only ever appended, so indices are stable
/// and no existing node (including a version root) is ever overwritten.
struct PST
{
    /// One tree node: total `sum` and `cnt` of the values inserted into its
    /// range, and the indices `l` / `r` of its children in `nod` (0 = empty).
    struct dak
    {
        long long sum=0;
        int cnt=0, l=0, r=0;
    };
    /// Index the next allocated node will receive (kept equal to nod.size()).
    int cur=0;
    /// Node pool, indexed by node id.
    std::vector<dak> nod;

    /// Appends an empty (all-zero) node.
    void add()
    {
        nod.emplace_back();
        cur = static_cast<int>(nod.size());
    }
    /// Builds (without storing) a node whose children are `l` and `r`.
    dak join(int l, int r)
    {
        return dak{nod[l].sum+nod[r].sum, nod[l].cnt+nod[r].cnt, l, r};
    }
    /// Stores `d` as a new node and returns its index.
    int make(const dak& d)
    {
        nod.push_back(d);
        cur = static_cast<int>(nod.size());
        return cur - 1;
    }
    /// Inserts value `p` into the version rooted at `id` (covering [l, r])
    /// and returns the root of the new version; the old version is untouched.
    /// Allocates exactly one node per level on the root-to-leaf path.
    int update(int id, int l, int r, int p)
    {
        if(l==r)
        {
            dak leaf = nod[id];   // copy: the source version must not change
            leaf.sum += p;
            leaf.cnt++;
            return make(leaf);
        }
        const int mid = l + (r - l) / 2;   // same split as (l+r)/2, no int overflow
        // Only indices (never references into `nod`) are held across the
        // recursive call, because the append inside it may reallocate the pool.
        if(p<=mid)
        {
            const int right = nod[id].r;
            const int left = update(nod[id].l, l, mid, p);
            return make(join(left, right));
        }
        const int left = nod[id].l;
        const int right = update(nod[id].r, mid+1, r, p);
        return make(join(left, right));
    }
    /// Returns the sum of the k largest values inserted after version `idl`
    /// up to version `idr` (both roots covering [l, r], `idr` derived from
    /// `idl`). If fewer than k such values exist, returns the sum of all.
    long long find(int idl, int idr, int l, int r, int k)
    {
        // Every value stored at a leaf equals l, so the leaf contributes l per copy.
        if(l==r) return 1LL * std::min(k, nod[idr].cnt-nod[idl].cnt) * l;
        const int mid = l + (r - l) / 2;
        const int le = nod[idl].r, ri = nod[idr].r;
        const int rightCnt = nod[ri].cnt - nod[le].cnt;
        // Enough values in the upper half: the answer lies entirely there.
        if(rightCnt >= k) return find(le, ri, mid+1, r, k);
        // Otherwise take the whole upper half and the remainder from the lower half.
        return nod[ri].sum - nod[le].sum + find(nod[idl].l, nod[idr].l, l, mid, k - rightCnt);
    }
} st;
// Number of input pairs, and the k passed to every st.find() query.
int n;
int k;
// root[i]: tree version after inserting a[1..i].fi (root[0] is the empty tree).
int root[nx];
// Input pairs, 1-indexed; main sorts a[1..n] with cmp before building the tree.
ii a[nx];
// Best cost() value seen so far; solve() raises it and main prints it.
ll res = -inf;
/// Value of the window a[l..r] of the sorted pairs: the sum of the k largest
/// `.fi` values among a[l..r] (queried from versions root[l-1]..root[r]) minus
/// 2 * (a[r].se - a[l].se). Returns -inf when the window lies outside [1, n]
/// or contains fewer than k pairs.
ll cost(int l, int r)
{
    // root[l-1] is read below, so l must be at least 1; the previous check
    // (l<0) let l == 0 through and would index root[-1].
    if(l<1 || r>n) return -inf;
    if(r-l+1<k) return -inf;
    return st.find(root[l-1], root[r], 1, base, k)-2ll*(a[r].se-a[l].se);
}
void solve(int l, int r, int optl, int optr)
{
    if(l>r) return;
    int mid=(l+r)>>1;
    int pos=0;
    ll cur=-inf-1;
    for(int i = optl; i <= min(mid, optr); i++)
        if(cost(i, mid)>cur)
            cur=cost(i, mid), pos=i;
    res=max(res, cur);
    solve(l, mid-1, optl, pos);
    solve(mid+1, r, pos, optr);
}
/// Reads n, k and n pairs, sorts the pairs by .se, builds one tree version per
/// prefix of that order, runs solve() over all right ends and prints `res`.
int main()
{
    ios_base::sync_with_stdio(false); cin.tie(nullptr);
    st.add();   // node 0: null child shared by all empty subtrees
    st.add();   // node 1: empty tree, used as version root[0]
    cin>>n>>k;
    // Every update() allocates one node per level of the tree over [1, base]
    // (a child covers at most ceil(len/2) values). Reserving the whole pool
    // up front avoids repeatedly reallocating and copying millions of nodes.
    int levels=1;
    for(long long len=base; len>1; len=(len+1)/2) levels++;
    st.nod.reserve(st.nod.size()+static_cast<size_t>(n)*levels);
    for(int i = 1; i <= n; i++)
        cin>>a[i].fi>>a[i].se;
    sort(a+1, a+n+1, cmp);
    root[0]=1;
    for(int i = 1; i <= n; i++)
        root[i]=st.update(root[i-1], 1, base, a[i].fi);
    solve(1, n, 1, n);
    cout<<res;
}
// | # | Verdict | Execution time | Memory | Grader output |
// Fetching results...
// | # | Verdict | Execution time | Memory | Grader output |
// Fetching results...
// | # | Verdict | Execution time | Memory | Grader output |
// Fetching results...