Submission #1235827

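| # | Time | Username | Problem | Language | Result | Execution time | Memory |
|---|------|----------|---------|----------|--------|----------------|--------|
| 1235827 | | kl0989e | 휴가 (IOI14_holiday) | C++20 | 100 / 100 | 1046 ms | 21168 KiB |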
#include "holiday.h"
#include <bits/stdc++.h>
using namespace std;

#define ll long long
#define pi pair<int, int>
#define pl pair<ll, ll>
#define vi vector<int>
#define vl vector<ll>
#define fi first
#define se second
#define pb push_back
#define all(x) (x).begin(),(x).end()

// Segment tree over a fixed array of values in which every position can be
// switched on ("active") or off. Each node stores how many active positions
// its range covers and the sum of their values, so get(x) can return the
// aggregate of the leftmost x active positions in O(log n). When the values
// are stored in decreasing order (as findMaxAttraction does), get(x).val is
// the sum of the x largest active values.
struct segTree {
    // One tree node. At a leaf, `tval` remembers the stored value so that it
    // can be restored on re-activation; `val` is always tval * act there.
    // For internal nodes only `act` and `val` are meaningful.
    struct node {
        int act;          // number of active positions covered
        long long val;    // sum of the values of those active positions
        long long tval;   // leaf only: the stored value, independent of state

        node(int count = 1, long long value = 0)
            : act(count), val(value), tval(value) {}
    };

    // Merges two sibling nodes (counts and sums add up).
    node unite(node a, node b) {
        const int count = a.act + b.act;
        const long long sum = a.val + b.val;
        return node(count, sum);
    }

    std::vector<node> nodes;  // heap layout: root 1, children of v at 2v, 2v+1
    int sze;                  // number of leaf positions

    // Builds the tree over val[0..n-1]; every position starts active.
    void init(int n, std::vector<int>& val) {
        sze = n;
        nodes.resize(4 * sze);
        build(1, 0, n - 1, val);
    }

    // Recursively fills node v, which covers positions [tl, tr].
    void build(int v, int tl, int tr, std::vector<int>& val) {
        if (tl != tr) {
            const int mid = tl + (tr - tl) / 2;
            build(2 * v, tl, mid, val);
            build(2 * v + 1, mid + 1, tr, val);
            nodes[v] = unite(nodes[2 * v], nodes[2 * v + 1]);
        } else {
            nodes[v] = node(1, val[tl]);
        }
    }

    // Sets the state of position ind (callers use 0 = off, 1 = on) inside
    // node v covering [tl, tr]; out-of-range positions are ignored.
    void chstate(int v, int tl, int tr, int ind, int state) {
        if (ind < tl || ind > tr) {
            return;
        }
        if (tl == tr) {
            node& leaf = nodes[v];
            leaf.act = state;
            leaf.val = leaf.tval * state;
            return;
        }
        const int mid = tl + (tr - tl) / 2;
        if (ind <= mid) {
            chstate(2 * v, tl, mid, ind, state);
        } else {
            chstate(2 * v + 1, mid + 1, tr, ind, state);
        }
        nodes[v] = unite(nodes[2 * v], nodes[2 * v + 1]);
    }
    void chstate(int ind, int state) {
        chstate(1, 0, sze - 1, ind, state);
    }

    // Aggregate of the leftmost x active positions inside node v (x >= 1).
    node get(int v, int tl, int tr, int x) {
        const node& cur = nodes[v];
        if (cur.act <= x) {
            return cur;  // the whole range fits into the budget
        }
        const int mid = tl + (tr - tl) / 2;
        const node& lhs = nodes[2 * v];
        if (x <= lhs.act) {
            return get(2 * v, tl, mid, x);
        }
        return unite(lhs, get(2 * v + 1, mid + 1, tr, x - lhs.act));
    }
    // Aggregate of the leftmost x active positions of the whole tree, or an
    // empty node(0, 0) when x <= 0.
    node get(int x) {
        return x > 0 ? get(1, 0, sze - 1, x) : node(0, 0);
    }
};

// Preallocated capacity of the per-city buffers (presumably the task's bound
// on n plus slack — confirm against the statement).
constexpr int maxn = 100000 + 10;

// inds[k]  : city with the k-th largest attraction in the current orientation.
// getind[c]: inverse of inds, i.e. the tree position of city c.
std::vector<int> inds(maxn);
std::vector<int> getind(maxn);

// Tree over the attractions sorted in decreasing order; a city's leaf is on
// while that city may be visited in the candidate interval being evaluated.
segTree dat;

// Starting city in the current orientation of the attraction array.
int strt;

// Per-day-budget results {best total, turning city} of the directional passes.
std::vector<std::pair<long long, long long>> ans1(3 * maxn);
std::vector<std::pair<long long, long long>> ans2(3 * maxn);

// Best single turning city for a budget of d days.
//
// Switches on every city in [tl, tr], then switches them off again from the
// right. Just before city i is switched off, the active cities inside the
// window are exactly tl..i, and turning at i is worth the sum of the
// (d - mult*(i - strt)) largest active attractions (0 if that count is <= 0).
// mult is the per-step travel cost: 2 for a trip that returns to strt, 1 for
// a one-way trip.
//
// Returns {best value, turning city}; on ties the smallest city wins. When it
// returns, every city of [tl, tr] is switched off.
pl findans(int d, int tl, int tr, int mult) {
    for (int city = tl; city <= tr; ++city) {
        dat.chstate(getind[city], 1);
    }
    pl best(-1, -1);
    for (int city = tr; city >= tl; --city) {
        const int visitDays = d - mult * (city - strt);
        const ll gain = dat.get(visitDays).val;
        dat.chstate(getind[city], 0);
        if (gain >= best.first) {
            best = pl(gain, city);
        }
    }
    return best;
}

// Fills ans2[j] = findans(j, ...) for every budget j in [l, r] by divide and
// conquer, searching turning cities only within [tl, tr]. The middle budget is
// solved directly; its turning city `opt` then limits smaller budgets to
// [tl, opt] and larger ones to [opt, tr] (the usual monotone-optimum
// assumption). Cities tl..opt are switched back on before the right half so
// they stay available to its candidate sets.
void div(int l, int r, int tl, int tr, int mult) {
    if (l > r) {
        return;
    }
    const int mid = l + (r - l) / 2;
    ans2[mid] = findans(mid, tl, tr, mult);
    if (l == r) {
        return;  // single budget: nothing left to split
    }
    const int opt = ans2[mid].second;
    div(l, mid - 1, tl, opt, mult);
    for (int city = tl; city <= opt; ++city) {
        dat.chstate(getind[city], 1);
    }
    div(mid + 1, r, opt, tr, mult);
}

// True while a call may still launch the single mirrored re-run below. The
// call that launches it clears the flag for the duration of the recursion and
// sets it back afterwards, so later top-level calls also get the mirrored pass.
bool other=1;

// Maximum total attraction obtainable in d days starting at city _strt of n
// cities on a line. As modelled below, moving to an adjacent city costs one
// day, and each visited city costs one day and contributes att[city] once.
//
// Pass 1 (mult = 2) computes, for every day budget, the best "go right and
// come back to the start" value. Pass 2 (mult = 1) mirrors the line and
// computes the best one-way trip leftwards starting one step from the start.
// The two are combined with one extra day for that step. The whole routine is
// re-run once on the mirrored array to cover the opposite order.
//
// att is reversed temporarily during the computation but is restored before
// returning. The global buffers are grown if n or d exceed their preallocated
// size.
ll findMaxAttraction(int n, int _strt, int d, int att[]) {
    // Nothing can be visited without cities or without days.
    if (n <= 0 || d <= 0) {
        return 0;
    }

    // The global buffers are preallocated; grow them if this input is larger.
    if ((int)inds.size() < n) {
        inds.resize(n);
        getind.resize(n);
    }
    if ((int)ans1.size() <= d) {
        ans1.resize(d+1);
    }
    if ((int)ans2.size() <= d) {
        ans2.resize(d+1);
    }

    strt=_strt;
    iota(inds.begin(),inds.begin()+n,0);
    sort(inds.begin(),inds.begin()+n,[&](int a, int b){return att[a]>att[b];});
    for (int i=0; i<n; i++) {
        getind[inds[i]]=i;
    }
    vi val(att,att+n);
    sort(all(val),[](int a, int b){return a>b;});

    // Pass 1: go right from strt and come back (two days per step).
    dat.init(n,val);
    for (int i=0; i<strt; i++) {
        dat.chstate(getind[i],0);
    }
    div(0,d,strt,n-1,2);

    if (strt==0) {
        // Nothing lies to the left, so only the round trip to the right is
        // computed here. The one-way trip to the right is found by the
        // mirrored re-run (start n-1 on the reversed array).
        ll ans=ans2[d].fi;
        if (other) {
            other=0;
            reverse(att,att+n);
            ans=max(ans,findMaxAttraction(n,n-1,d,att));
            reverse(att,att+n);  // give the caller its array back unchanged
            other=1;             // keep the mirrored pass for later calls
        }
        return ans;
    }

    // Keep the round-trip table in ans1; div always writes into ans2.
    swap(ans1,ans2);

    // Pass 2: mirror the line and go one way from the city just left of the
    // original start (mirrored index n-strt), one day per step.
    reverse(att,att+n);
    sort(inds.begin(),inds.begin()+n,[&](int a, int b){return att[a]>att[b];});
    strt=n-strt;
    for (int i=0; i<n; i++) {
        getind[inds[i]]=i;
    }
    dat.init(n,val);  // reversal keeps the multiset, so sorted val is reused
    for (int i=0; i<strt; i++) {
        dat.chstate(getind[i],0);
    }
    div(0,d,strt,n-1,1);
    strt=n-strt;  // start index again, in the orientation pass 1 used

    // Spend i days on the round trip and the rest, minus the one day needed
    // to step beyond the start, on the one-way trip.
    ll ans=0;
    for (int i=0; i<=d; i++) {
        ans=max(ans,ans1[i].fi+ans2[max(d-1-i,0)].fi);
    }

    // Opposite order: round trip to the other side, then one way back past
    // the start. att is still mirrored here, which is what that run needs.
    if (other) {
        other=0;
        ans=max(ans,findMaxAttraction(n,n-1-strt,d,att));
        other=1;  // keep the mirrored pass for later calls
    }
    reverse(att,att+n);  // restore the caller's orientation
    return ans;
}
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...