Submission #1147313

# | Time | Username | Problem | Language | Result | Execution time | Memory
1147313 | Math4Life2020 | Weirdtree (RMI21_weirdtree) | C++20 | 42 / 100 | 2118 ms | 427916 KiB
#include "weirdtree.h"
#include <bits/stdc++.h>
using namespace std;
using ll = long long; using pii = pair<ll,ll>;

// Nm is the number of leaves of the tree (a power of two) and E is log2(Nm).
// INF is a sentinel: lz[p]==INF means "no pending tag", and the key -INF is
// stored in every stn map as a floor entry.
const ll Nm = (1<<19); const ll E = 19; const ll INF = 1e9+5;
ll N; // number of elements, set by initialise()

// All per-node arrays are heap-indexed: the root is 1, the children of p are
// 2p and 2p+1, and position i (0-based) is the leaf i+Nm.
map<int,int> stn[2*Nm]; // per node: value -> how many leaves below hold it (plus the -INF floor key)
long long sts[2*Nm]; // per node: sum of the values below
ll stm[2*Nm]; // per node: largest value below
ll st2m[2*Nm]; // per node: second-largest distinct value below (-INF when there is none)
//SEGTREE BEATS BABYYYYY
ll lz[2*Nm]; // lazy tag: pending "cap every value at lz[p]" (INF = nothing pending), resolved by pdn()
ll h[Nm]; // 0-based copy of the input values; only read while initialise() builds the tree

void initialise(int _N, int _Q, int _h[]) {
    N = _N;
    for (ll i=0;i<N;i++) {
        h[i]=_h[i+1];
    }
    for (ll i=0;i<Nm;i++) {
        sts[i+Nm]=h[i];
        stm[i+Nm]=h[i];
        st2m[i+Nm]=-INF;
        stn[i+Nm][h[i]]=1;
        stn[i+Nm][-INF]=1;
    }
    for (ll p=(Nm-1);p>=1;p--) {
        for (pii p0: stn[2*p]) {
            stn[p][p0.first]+=p0.second;
        }
        for (pii p0: stn[2*p+1]) {
            stn[p][p0.first]+=p0.second;
        }
        sts[p]=sts[2*p]+sts[2*p+1];
        stm[p]=max(stm[2*p],stm[2*p+1]);
        st2m[p]=max(st2m[2*p],st2m[2*p+1]);
        if (stm[2*p]!=stm[2*p+1]) {
            st2m[p]=max(st2m[p],min(stm[2*p],stm[2*p+1]));
        }
    }
    for (ll p=0;p<(2*Nm);p++) {
        lz[p]=INF;
    }
}

/**
 * Returns the number of trailing zero bits of x (the exponent of the largest
 * power of two dividing x).
 *
 * For x == 0 the result is 64, i.e. larger than the trailing-zero count of
 * any non-zero value. The range decompositions in this file compare
 * v2(left) < v2(right + 1) and call v2(0) whenever the left end is position
 * 0; the previous __builtin_ctz(0) was undefined behaviour there. The 64-bit
 * builtin is used so the argument is no longer truncated to 32 bits.
 */
inline long long v2(long long x) {
    if (x == 0) {
        return 64;
    }
    return __builtin_ctzll(static_cast<unsigned long long>(x));
}

/**
 * Returns floor(log2(x)), i.e. the index of the highest set bit of x.
 *
 * Uses the 64-bit builtin so that values of 2^32 and above are not truncated
 * (the 32-bit __builtin_clz silently dropped the high bits). For x <= 0,
 * where the logarithm is undefined and the builtin would be undefined for 0,
 * the function returns -1.
 */
inline long long l2(long long x) {
    if (x <= 0) {
        return -1;
    }
    return 63 - __builtin_clzll(static_cast<unsigned long long>(x));
}

/**
 * Resolves the lazy tag of node p0.
 *
 * Every value in the node's multiset that exceeds the pending cap is lowered
 * to the cap (adjusting the sum accordingly), max / second max are re-read
 * from the multiset, the cap is handed to both children (if any) and the tag
 * is cleared. Does nothing for p0 == 0 or when no tag is pending.
 */
void pdn(ll p0) {
    if (p0 == 0) {
        return;
    }
    const ll cap = lz[p0];
    if (cap == INF) {
        return;
    }

    auto& counts = stn[p0];

    // Repeatedly take the largest key; stop as soon as it respects the cap.
    while (!counts.empty()) {
        const auto top = prev(counts.end());
        const ll value = top->first;
        if (value <= cap) {
            break;
        }
        const ll amount = top->second;
        counts.erase(top);
        sts[p0] -= amount * (value - cap);
        counts[cap] += amount;
    }

    // The two largest keys are the new maximum and second maximum.
    const auto largest = prev(counts.end());
    stm[p0] = largest->first;
    st2m[p0] = prev(largest)->first;

    // Leaves (p0 >= Nm) have no children to forward the cap to.
    if (p0 < Nm) {
        const ll lc = 2 * p0;
        const ll rc = lc + 1;
        lz[lc] = min(lz[lc], cap);
        lz[rc] = min(lz[rc], cap);
    }
    lz[p0] = INF;
}

/**
 * Intentionally a no-op.
 *
 * This function used to contain a loop calling pdn() on every ancestor of p0
 * (root first), but an unconditional `return` placed before it made that loop
 * unreachable. The dead loop has been removed; behaviour is unchanged. The
 * callers in this file resolve tags themselves with pdn() before calling this.
 *
 * @param p0  node index (unused)
 */
void pdnP(long long p0) {
    (void)p0;
}

/**
 * Lowers by one the leftmost k occurrences of value lt inside the subtree of
 * node p0, and updates p0's own multiset / sum / max / second max.
 *
 * Ancestors of p0 are NOT updated here; the caller is responsible for that.
 *
 * Preconditions (asserted): k >= 0 and the subtree holds at least k values
 * equal to lt. The lazy cap lt-1 used below is only equivalent to "lower the
 * lt values" when lt is the largest value in the subtree.
 * NOTE(review): that appears to hold for every call made from cut(); confirm
 * before reusing this helper elsewhere.
 *
 * @param p0  node index (0 is treated as "no node")
 * @param k   how many occurrences of lt to lower
 * @param lt  the value being lowered
 */
void cut2(ll p0, ll k, ll lt) {
    assert(k>=0);
    // Logical OR was intended here; the original used bitwise '|' for the
    // last operand.
    if (k==0 || lt==0 || p0==0) {
        return;
    }
    pdn(p0);
    pdnP(p0);

    auto& counts = stn[p0];
    // Look the key up without operator[], so the check itself never inserts
    // an entry into the multiset.
    const auto atLt = counts.find(lt);
    assert(atLt != counts.end() && atLt->second >= k);

    if (atLt->second == k) {
        // Every lt in this subtree is lowered: a lazy cap of lt-1 on the
        // descendants does the job.
        lz[p0] = lt-1;
    } else {
        // Only some are lowered, so the split must be decided in the children.
        assert(p0<Nm);
        const ll lc = 2*p0;
        const ll rc = 2*p0+1;
        pdn(lc);
        pdnP(lc);
        pdn(rc);
        pdnP(rc);

        const auto inLeft = stn[lc].find(lt);
        if (inLeft == stn[lc].end()) {
            // Left child has none: everything comes from the right.
            cut2(rc, k, lt);
        } else if (inLeft->second >= k) {
            // Left child alone can supply the k leftmost occurrences.
            cut2(lc, k, lt);
        } else {
            // Left child is exhausted; the remainder comes from the right.
            // Read the count before recursing, since the left call removes
            // the entry.
            const ll leftCount = inLeft->second;
            cut2(rc, k-leftCount, lt);
            cut2(lc, leftCount, lt);
        }
    }

    // Move k occurrences from lt to lt-1 in this node's multiset.
    if (atLt->second == k) {
        counts.erase(atLt);
    } else {
        atLt->second -= k;
    }
    counts[lt-1] += k;
    // Net effect on the sum of replacing k copies of lt by lt-1.
    sts[p0] -= k;

    const auto largest = prev(counts.end());
    stm[p0] = largest->first;
    st2m[p0] = prev(largest)->first;
}

/**
 * Applies "count of value lt changes by delta" to a set of nodes and to all
 * of their ancestors.
 *
 * @param m0  node index -> delta; taken by value because it is consumed and
 *            extended with parent entries while propagating upwards
 * @param lt  the value whose multiplicity changes
 *
 * Nodes are handled from the largest index downwards, so every child is
 * processed before its parent and deltas meeting at a common ancestor are
 * summed first. Propagation stops once index 0 (above the root) is reached.
 */
void lft(map<ll,ll> m0, ll lt) {
    while (!m0.empty()) {
        const auto last = prev(m0.end());
        const ll node = last->first;
        const ll delta = last->second;
        m0.erase(last);

        if (node < 1) {
            return;
        }

        sts[node] += lt * delta;

        auto& counts = stn[node];
        int& slot = counts[lt];
        slot += delta;
        if (slot == 0) {
            // Drop empty entries so the largest keys remain real values.
            counts.erase(lt);
        }

        const auto largest = prev(counts.end());
        stm[node] = largest->first;
        st2m[node] = prev(largest)->first;

        // Hand the same delta to the parent.
        m0[node / 2] += delta;
    }
}

// cut(l, r, k): performs k unit decrements on the 1-indexed inclusive range
// [l, r]. Each decrement targets the current largest value of the range; among
// equal maxima the leftmost positions are lowered first. Values are never
// taken below 0 (the function returns once the range maximum is 0).
// NOTE(review): presumably this is the "cut" operation of the Weirdtree task;
// the exact statement is not visible here.
//
// Outline:
//   1. resolve lazy tags along the root-to-leaf paths of l and r;
//   2. split [l, r] into aligned power-of-two blocks (tree nodes) and find the
//      maximum lt, the blocks attaining it, and the next lower value l2t;
//   3. if at least one full "round" (every maximum lowered by one) is
//      affordable, lower all maxima by REV at once and recurse with the
//      remaining k;
//   4. otherwise spend the remaining k on the leftmost maxima only.
void cut(int l, int r, int k) {
    //cout << "CUT CALL\n";
    if (k<=0) {
        return;
    }
    l--; r--; // switch to 0-based positions
    // Resolve pending tags from the root (e=E) down to the leaves of l and r,
    // and also on the nodes immediately right of l's path / left of r's path.
    for (ll e=E;e>=0;e--) {
        pdn((l>>e)+(1<<(E-e)));
        pdn((r>>e)+(1<<(E-e)));
        pdn((l>>e)+1+(1<<(E-e)));
        pdn((r>>e)-1+(1<<(E-e)));
    }
    // lt: largest value found so far; vp0: blocks whose maximum equals lt;
    // vpc: every block of the decomposition; l2t: next lower value (floor 0).
    ll lt = 0; vector<ll> vp0; vector<ll> vpc;
    ll l2t = 0;
    ll l0 = l; ll r0 = r;
    // Peel aligned power-of-two blocks off whichever end allows the smaller
    // block: the left end when l0 has fewer trailing zeros than r0+1,
    // otherwise the right end.
    while (l0<=r0) {
        ll vl = v2(l0); ll vr = v2(r0+1);
        if (vl<vr) {
            // Node covering [l0, l0 + 2^vl).
            ll pc = (l0>>vl)+(1<<(E-vl));
            vpc.push_back(pc);
            pdnP(pc);
            pdn(pc);
            if (stm[pc]>lt) {
                vp0.clear();
                vp0.push_back(pc);
                //l2t = max(l2t,lt);
                lt = stm[pc];
            } else if (stm[pc]==lt) {
                vp0.push_back(pc);
            } else {
                //l2t = max(l2t,stm[pc]);
            }
            //l2t = max(l2t,st2m[pc]);
            l0 += (1<<vl);
        } else {
            // Node covering [r0 + 1 - 2^vr, r0].
            ll pc = (r0>>vr)+(1<<(E-vr));
            vpc.push_back(pc);
            pdnP(pc);
            pdn(pc);
            if (stm[pc]>lt) {
                vp0.clear();
                vp0.push_back(pc);
                //l2t = max(l2t,lt);
                lt = stm[pc];
            } else if (stm[pc]==lt) {
                vp0.push_back(pc);
            }
            //l2t = max(l2t,st2m[pc]);
            r0 -= (1<<vr);
        }
    }
    // Nothing above 0 in the range: there is nothing left to lower.
    if (lt==0) {
        return;
    }
    // Next lower value: second maximum of the blocks that attain lt, plain
    // maximum of the others.
    for (ll p0: vpc) {
        if (stm[p0]==lt) {
            l2t = max(st2m[p0],l2t);
        } else {
            l2t = max(stm[p0],l2t);
        }
    }
    // Order the maximal blocks left to right: xn is the leftmost position
    // covered by node p0.
    vector<pii> vpn; 
    for (ll p0: vp0) {
        ll lp = l2(p0);
        ll xn = (p0-(1<<lp))<<(E-lp);
        vpn.push_back({xn,p0});
    }
    sort(vpn.begin(),vpn.end());
    vp0.clear();
    for (pii p1: vpn) {
        vp0.push_back(p1.second);
    }
    // for (ll p0: vp0) {
    //     cout << "p0 in vp0: "<<p0<<"\n";
    // }
    ll PDEL = -1; // block that can only be lowered partially (-1: none)
    map<ll,ll> lftm; // node -> change in the count of value lt
    //cout << "lt="<<lt<<"\n";
    //cout << stn[2][2]<<"\n";
    // nsum: how many positions of the range currently hold lt.
    ll nsum = 0;
    for (ll p0: vp0) {
        if (stn[p0].find(lt)==stn[p0].end()) {
            continue;
        }
        nsum += stn[p0][lt];
    }
    // REV: number of full rounds to apply at once, limited by the budget
    // (k/nsum) and by the gap down to the next lower value.
    ll REV=0;
    REV = min(k/nsum,lt-l2t);
    if (REV>0) {
        map<ll,ll> lftm2; // node -> change in the count of value lt-REV
        for (ll p0: vp0) {
            //cout << "k="<<k<<", p0="<<p0<<", lt="<<lt<<", stn[p0][lt]="<<stn[p0][lt]<<"\n";
            if (stn[p0].find(lt)==stn[p0].end()) {
                continue;
            }
            lftm[p0]+= -stn[p0][lt];
            lftm2[p0]+= stn[p0][lt];
            k -= REV*stn[p0][lt];
            // Descendants are updated lazily through the cap.
            lz[p0]=min(lz[p0],lt-REV);
        }
        // Update the blocks and all their ancestors: add at lt-REV first,
        // then remove at lt.
        lft(lftm2,lt-REV);
        lft(lftm,lt);
        for (ll p0: vp0) {
            assert(stm[p0]>=(lt-REV));
        }
        // Spend whatever is left of k on the new state of the range.
        cut(l+1,r+1,k);
        return;
    }
    // Partial round. NOTE(review): REV==0 with lt>l2t implies k<nsum here.
    // Walk the maximal blocks left to right, lowering whole blocks while the
    // budget lasts.
    for (ll p0: vp0) {
        //cout << "k="<<k<<", p0="<<p0<<", lt="<<lt<<", stn[p0][lt]="<<stn[p0][lt]<<"\n";
        if (stn[p0][lt]<=k) {
            lftm[p0]+= -stn[p0][lt];
            k -= stn[p0][lt];
            lz[p0]=min(lz[p0],lt-1);
        } else {
            PDEL = p0;
            break;
        }
        if (k==0) {
            break;
        }
    }
    if (PDEL!=-1) {
        // cut2 lowers the leftmost k maxima inside PDEL and fixes PDEL itself;
        // its ancestors are fixed by lft, starting from its parent.
        cut2(PDEL,k,lt);
        lftm[PDEL/2] += -k;
        k=0;
    }
    // Mirror the removals at lt as additions at lt-1.
    map<ll,ll> lftm2;
    for (pii p0: lftm) {
        lftm2[p0.first]=-p0.second;
    }
    lft(lftm2,lt-1);
    lft(lftm,lt);
    // Hand the caps set above to the children of the maximal blocks.
    for (ll p0: vp0) {
        pdn(p0);
    }
    if (PDEL==-1 && k>0) {
        cut(l+1,r+1,k);
    }
}

void magic(int i, int x) {
    i--;
    for (ll e=E;e>=0;e--) {
        pdn((i>>e)+(1<<(E-e)));
    }
    ll h0 = stm[i+Nm];
    map<ll,ll> m0;
    m0[i+Nm]=1;
    lft(m0,x);
    map<ll,ll> m1;
    m1[i+Nm]=-1;
    lft(m1,h0);
}

long long inspect(int l, int r) {
    l--; r--;
    long long fval = 0;
    for (ll e=E;e>=0;e--) {
        pdn((l>>e)+(1<<(E-e)));
        pdn((r>>e)+(1<<(E-e)));
        pdn((l>>e)+1+(1<<(E-e)));
        pdn((r>>e)-1+(1<<(E-e)));
    }
    while (l<=r) {
        ll vl = v2(l); ll vr = v2(r+1);
        if (vl<vr) {
            pdnP((l>>vl)+(1<<(E-vl)));
            pdn((l>>vl)+(1<<(E-vl)));
            fval += sts[(l>>vl)+(1<<(E-vl))];
            //cout << "l,vl="<<l<<","<<vl<<"\n";
            //cout << "ins="<<((l>>vl)+(1<<(E-vl)))<<", outs="<<sts[(l>>vl)+(1<<(E-vl))]<<"\n";
            l += (1<<vl);
        } else {
            pdnP((r>>vr)+(1<<(E-vr)));
            pdn((r>>vr)+(1<<(E-vr)));
            fval += sts[(r>>vr)+(1<<(E-vr))];
            //cout << "r,vr="<<r<<","<<vr<<"\n";
            //cout << "ins="<<((r>>vr)+(1<<(E-vr)))<<", outs="<<sts[(r>>vr)+(1<<(E-vr))]<<"\n";
            r -= (1<<vr);
        }
    }
    return fval;
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...