Submission #1126273

| # | Time | Username | Problem | Language | Result | Execution time | Memory |
| 1126273 | | dwuy | Weirdtree (RMI21_weirdtree) | C++20 | 100 / 100 | 412 ms | 50676 KiB |
#include "weirdtree.h"
#include <bits/stdc++.h>
#define ll long long
using namespace std;

// Summary of one contiguous segment of the array, as stored in a single
// segment-tree node. The defaults describe a segment with no elements.
struct Node{
    ll fm = 0;   // maximum value in the segment
    ll sm = -1;  // largest value strictly below fm; -1 (the default) means none
    ll cm = 0;   // how many elements of the segment equal fm
    ll lz = 0;   // delta already added to fm but not yet pushed to the children holding the maximum
    ll sum = 0;  // sum of all values in the segment
};

struct SMT{
    ll n;
    vector<Node> tree;
    SMT(ll n = 0) : n(n), tree(n<<2|3, Node()) {};

    void down(ll id){
        if(tree[id].lz == 0) return;
        ll delta = tree[id].lz;
        if(tree[id<<1].fm + delta == tree[id].fm){
            tree[id<<1].lz += delta;
            tree[id<<1].fm += delta;
            tree[id<<1].sum += delta * tree[id<<1].cm;
        }
        if(tree[id<<1|1].fm + delta == tree[id].fm){
            tree[id<<1|1].lz += delta;
            tree[id<<1|1].fm += delta;
            tree[id<<1|1].sum += delta * tree[id<<1|1].cm;
        }
        tree[id].lz = 0;
    }

    void update(ll id){
        Node &L = tree[id<<1];
        Node &R = tree[id<<1|1];
        tree[id].sum = L.sum + R.sum;
        tree[id].fm = max(L.fm, R.fm);
        tree[id].cm = 0;
        if(tree[id].fm == L.fm) tree[id].cm += L.cm;
        if(tree[id].fm == R.fm) tree[id].cm += R.cm;
        if(L.fm != R.fm) tree[id].sm = max({L.sm, R.sm, min(L.fm, R.fm)});
        else tree[id].sm = max(L.sm, R.sm);
    }

    Node combine(Node L, Node R){
        Node res;
        res.sum = L.sum + R.sum;
        res.fm = max(L.fm, R.fm);
        res.cm = 0;
        if(res.fm == L.fm) res.cm += L.cm;
        if(res.fm == R.fm) res.cm += R.cm;
        if(L.fm != R.fm) res.sm = max({L.sm, R.sm, min(L.fm, R.fm)});
        else res.sm = max(L.sm, R.sm);
        return res;
    }

    void assign(ll pos, ll val){
        ll id = 1;
        for(ll lo=1, hi=n; lo<hi;){
            ll mid = (lo + hi)>>1;
            down(id);
            if(pos <= mid) id = id<<1, hi = mid;
            else lo = mid + 1, id = id<<1 | 1;
        }
        tree[id].fm = tree[id].sum = val;
        tree[id].cm = 1;
        for(id>>=1; id; id>>=1) update(id);
    }

    Node gmax(ll l, ll r, ll id, const ll &u, const ll &v){
        if(l > v || r < u) return Node();
        if(l >= u && r <= v) return tree[id];
        down(id);
        ll mid = (l + r)>>1;
        return combine(gmax(l, mid, id<<1, u, v), gmax(mid + 1, r, id<<1|1, u, v));
    }

    Node gmax(ll l, ll r){
        return gmax(1, n, 1, l, r);
    }

    ll gsum(ll l, ll r, ll id, const ll &u, const ll &v){
        if(l > v || r < u) return 0;
        if(l >= u && r <= v) return tree[id].sum;
        down(id);
        ll mid = (l + r)>>1;
        return gsum(l, mid, id<<1, u, v) + gsum(mid + 1, r, id<<1|1, u, v);
    }

    ll gsum(ll l, ll r){
        return gsum(1, n, 1, l, r);
    }

    // ll gmid(ll l, ll r, ll id, const ll &u, const ll &v, const ll &val){
    //     if(l > v || r < u || tree[id].fm <= val) return 0;
    //     if(l >= u && r <= v && tree[id].sm < val) return tree[id].cm*(tree[id].fm - val);
    //     down(id);
    //     ll mid = (l + r)>>1;
    //     return gmid(l, mid, id<<1, u, v, val) + gmid(mid + 1, r, id<<1|1, u, v, val);
    // }

    // ll gmid(ll l, ll r, ll val){
    //     return gmid(1, n, 1, l, r, val);
    // }

    void zmin(ll l, ll r, ll id, const ll &u, const ll &v, const ll &val){
        if(l > v || r < u || tree[id].fm <= val) return;
        if(l >= u && r <= v && tree[id].sm < val){
            tree[id].sum -= tree[id].cm * (tree[id].fm - val);
            tree[id].lz += val - tree[id].fm;
            tree[id].fm = val;
            return;
        }
        down(id);
        ll mid = (l + r)>>1;
        zmin(l, mid, id<<1, u, v, val);
        zmin(mid + 1, r, id<<1|1, u, v, val);
        update(id);
    }
    
    void zmin(ll l, ll r, ll val){
        zmin(1, n, 1, l, r, val);
    }

    void pmin(ll l, ll r, ll id, const ll &u, const ll &v, const ll &mx, ll &k){
        if(l > v || r < u || k == 0 || tree[id].fm < mx) return;
        if(l >= u && r <= v && tree[id].sm < tree[id].fm - 1 && tree[id].fm == mx && tree[id].cm <= k){
            k -= tree[id].cm;
            tree[id].sum -= tree[id].cm;
            tree[id].lz += -1;
            tree[id].fm--;
            // cout << l << ' ' << r << " - " << tree[id].fm << ' ' << tree[id].cm << ' ' << tree[id].sum << ' ' << tree[id].lz <<  endl;
            return;
        }
        down(id);
        ll mid = (l + r)>>1;
        pmin(l, mid, id<<1, u, v, mx, k);
        pmin(mid + 1, r, id<<1|1, u, v, mx, k);
        update(id);   
    }

    void pmin(ll l, ll r, ll k){
        pmin(1, n, 1, l, r, gmax(l, r).fm, k);
    }
} smt;

// Array length and query count passed to initialise().
// NOTE(review): q is written by initialise() but never read by the code visible here.
ll n;
ll q;
void initialise(int N, int Q, int h[]){
    n = N;
    q = Q;
    smt = SMT(n);
    for(ll i=1; i<=n; i++) smt.assign(i, h[i]);
}

void cut(int l, int r, int k){
    while(k){
        Node cur = smt.gmax(l, r);
        if(cur.fm == 0){
            k = 0;
            break;
        }
        cur.sm = max(cur.sm, 0LL);
        if(cur.cm*(cur.fm - cur.sm) <= k){
            k -= cur.cm*(cur.fm - cur.sm);
            smt.zmin(l, r, cur.sm);
        }
        else{
            int d = k/cur.cm;
            smt.zmin(l, r, cur.fm - d);
            k %= cur.cm;
            break;
        }
    }
    if(k) smt.pmin(l, r, k);
}

void magic(int p, int x){
    smt.assign(p, x);
}

// Grader entry point: sum of the values at positions [l, r].
ll inspect(int l, int r){
    const ll total = smt.gsum(l, r);
    return total;
}
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...