Submission #1219229

#TimeUsernameProblemLanguageResultExecution timeMemory
1219229green_gold_dogWeirdtree (RMI21_weirdtree)C++20
0 / 100
273 ms60512 KiB
//#pragma GCC optimize("Ofast")
//#pragma GCC target("avx,avx2,sse,sse2,sse3,ssse3,sse4,abm,popcnt,mmx")
#include <bits/stdc++.h>
#include "weirdtree.h"

using namespace std;

typedef long long ll;
typedef double db;
typedef long double ldb;
typedef complex<double> cd;

constexpr ll INF64 = 9'000'000'000'000'000'000, INF32 = 2'000'000'000, MOD = 1'000'000'007;
constexpr db PI = acos(-1);
constexpr bool IS_FILE = false, IS_TEST_CASES = false;

random_device rd;
mt19937 rnd32(rd());
mt19937_64 rnd64(rd());

template<typename T>
bool assign_max(T& a, T b) {
        if (b > a) {
                a = b;
                return true;
        }
        return false;
}

template<typename T>
bool assign_min(T& a, T b) {
        if (b < a) {
                a = b;
                return true;
        }
        return false;
}

template<typename T>
T square(T a) {
        return a * a;
}

template<>
struct std::hash<pair<ll, ll>> {
        ll operator() (pair<ll, ll> p) const {
                return ((__int128)p.first * MOD + p.second) % INF64;
        }
};

tuple<ll, ll, ll, ll, ll> merge(tuple<ll, ll, ll, ll, ll> t1, tuple<ll, ll, ll, ll, ll> t2) {
        ll mx = 0, cmx = 0, smx = 0, sum = get<3>(t1) + get<3>(t2);
        if (get<0>(t1) > get<0>(t2)) {
                mx = get<0>(t1);
                cmx = get<1>(t1);
                smx = max(get<2>(t1), get<0>(t2));
        }
        if (get<0>(t1) < get<0>(t2)) {
                mx = get<0>(t2);
                cmx = get<1>(t2);
                smx = max(get<2>(t2), get<0>(t1));
        }
        if (get<0>(t1) == get<0>(t2)) {
                mx = get<0>(t1);
                cmx = get<1>(t1) + get<1>(t2);
                smx = max(get<2>(t2), get<2>(t1));
        }
        return make_tuple(mx, cmx, smx, sum, mx);
}

struct segment_tree {
        vector<tuple<ll, ll, ll, ll, ll>> arr;
        vector<ll> ml, mr;
        ll sz;
        segment_tree(ll n = 0) {
                sz = 1;
                while (sz < n) {
                        sz *= 2;
                }
                arr.resize(sz * 2, make_tuple(0, 0, 0, 0, 0));
                ml.resize(sz * 2, 0);
                mr.resize(sz * 2, 0);
        }
        void pull(ll v) {
                arr[v] = merge(arr[v * 2], arr[v * 2 + 1]);
        }
        void apply(ll v, ll l, ll r, ll x) {
                if (x == 0) {
                        return;
                }
                get<3>(arr[v]) -= x;
                if (get<3>(arr[v]) < 0) {
                        get<3>(arr[v]) = 0;
                }
                if (r - l == 1) {
                        get<0>(arr[v]) = max(0ll, get<0>(arr[v]) - x);
                        return;
                }
                if (get<3>(arr[v]) == 0) {
                        get<0>(arr[v]) = 0;
                        get<1>(arr[v]) = r - l;
                        get<2>(arr[v]) = 0;
                        ml[v] = get<3>(arr[v * 2]);
                        mr[v] = get<3>(arr[v * 2 + 1]);
                        return;
                }
                while ((get<0>(arr[v]) - get<2>(arr[v])) * get<1>(arr[v]) <= x) {
                        ll mid = (l + r) / 2;
                        ll cmxl = 0, cmxr = 0;
                        if (get<0>(arr[v * 2]) == get<4>(arr[v])) {
                                cmxl = get<1>(arr[v * 2]);
                        }
                        if (get<0>(arr[v * 2 + 1]) == get<4>(arr[v])) {
                                cmxr = get<1>(arr[v * 2 + 1]);
                        }
                        apply(v * 2, l, mid, (get<0>(arr[v]) - get<2>(arr[v])) * cmxl + ml[v]);
                        apply(v * 2 + 1, mid, r, (get<0>(arr[v]) - get<2>(arr[v])) * cmxr + mr[v]);
                        ml[v] = 0;
                        mr[v] = 0;
                        x -= (get<0>(arr[v]) - get<2>(arr[v])) * get<1>(arr[v]);
                        pull(v);
                }
                ll cmxl = 0, cmxr = 0;
                if (get<0>(arr[v * 2]) == get<4>(arr[v])) {
                        cmxl = get<1>(arr[v * 2]);
                }
                if (get<0>(arr[v * 2 + 1]) == get<4>(arr[v])) {
                        cmxr = get<1>(arr[v * 2 + 1]);
                }
                ll fu = x / get<1>(arr[v]);
                x -= fu * get<1>(arr[v]);
                get<0>(arr[v]) -= fu;
                ml[v] += fu * cmxl;
                mr[v] += fu * cmxr;
                if (x != 0) {
                        get<2>(arr[v]) = get<0>(arr[v]) - 1;
                        get<1>(arr[v]) -= x;
                        if (get<0>(arr[v * 2]) == get<4>(arr[v])) {
                                ml[v] += min(get<1>(arr[v * 2]), x);
                                x -= min(get<1>(arr[v * 2]), x);
                        }
                        if (get<0>(arr[v * 2 + 1]) == get<4>(arr[v])) {
                                mr[v] += min(get<1>(arr[v * 2 + 1]), x);
                                x -= min(get<1>(arr[v * 2 + 1]), x);
                        }
                }
        }
        void push(ll v, ll l, ll r) {
                ll mid = (l + r) / 2;
                apply(v * 2, l, mid, ml[v]);
                apply(v * 2 + 1, mid, r, mr[v]);
                ml[v] = 0;
                mr[v] = 0;
        }
        ll getsum(ll l, ll r) {
                return getsum(1, 0, sz, l, r);
        }
        ll getsum(ll v, ll l, ll r, ll ql, ll qr) {
                if (ql <= l && r <= qr) {
                        return get<3>(arr[v]);
                }
                if (qr <= l || r <= ql) {
                        return 0;
                }
                ll mid = (l + r) / 2;
                push(v, l, r);
                return getsum(v * 2, l, mid, ql, qr) + getsum(v * 2 + 1, mid, r, ql, qr);
        }
        void set(ll x, ll y) {
                set(1, 0, sz, x, y);
        }
        void set(ll v, ll l, ll r, ll x, ll y) {
                if (x < l || r <= x) {
                        return;
                }
                if (r - l == 1) {
                        arr[v] = make_tuple(y, 1, 0, y, y);
                        return;
                }
                ll mid = (l + r) / 2;
                push(v, l, r);
                set(v * 2, l, mid, x, y);
                set(v * 2 + 1, mid, r, x, y);
                pull(v);
        }
        void pullall(ll l, ll r) {
                pullall(1, 0, sz, l, r);
        }
        void pullall(ll v, ll l, ll r, ll ql, ll qr) {
                if (ql <= l && r <= qr) {
                        return;
                }
                if (qr <= l || r <= ql) {
                        return;
                }
                ll mid = (l + r) / 2;
                pullall(v * 2, l, mid, ql, qr);
                pullall(v * 2 + 1, mid, r, ql, qr);
                pull(v);
        }
        vector<tuple<ll, ll, ll>> sub(ll l, ll r) {
                vector<tuple<ll, ll, ll>> ans;
                sub(1, 0, sz, l, r, ans);
                return ans;
        }
        void sub(ll v, ll l, ll r, ll ql, ll qr, vector<tuple<ll, ll, ll>>& ans) {
                if (ql <= l && r <= qr) {
                        ans.emplace_back(v, l, r);
                        return;
                }
                if (qr <= l || r <= ql) {
                        return;
                }
                ll mid = (l + r) / 2;
                push(v, l, r);
                sub(v * 2, l, mid, ql, qr, ans);
                sub(v * 2 + 1, mid, r, ql, qr, ans);
        }
        ll gs(ll v) {
                return get<3>(arr[v]);
        }
        ll gmx(ll v) {
                return get<0>(arr[v]);
        }
        ll gcmx(ll v) {
                return get<1>(arr[v]);
        }
        tuple<ll, ll, ll, ll, ll> gett(ll v) {
                return arr[v];
        }
};

segment_tree st;

void initialise(int n, int q, int h[]) {
        st = segment_tree(n);
        for (ll i = 0; i < n; i++) {
                st.set(i, h[i + 1]);
        }
}

void cut(int l, int r, int kq) {
        ll k = kq;
        l--;
        vector<tuple<ll, ll, ll>> as = st.sub(l, r);
        if (k >= st.getsum(l, r)) {
                for (auto[v, l, r] : as) {
                        st.apply(v, l, r, st.gs(v));
                }
                st.pullall(l, r);
                return;
        }
        while (k) {
                tuple<ll, ll, ll, ll, ll> now(0, 0, 0, 0, 0);
                for (auto[v, l, r] : as) {
                        now = merge(now, st.gett(v));
                }
                if ((get<0>(now) - get<2>(now)) * get<1>(now) <= k) {
                        for (auto[v, l, r] : as) {
                                if (st.gmx(v) == get<4>(now)) {
                                        st.apply(v, l, r, st.gcmx(v) * (get<0>(now) - get<2>(now)));
                                }
                        }
                        k -= (get<0>(now) - get<2>(now)) * get<1>(now);
                } else {
                        ll fc = k / get<1>(now);
                        k -= get<1>(now) * fc;
                        for (auto[v, l, r] : as) {
                                if (st.gmx(v) == get<4>(now)) {
                                        st.apply(v, l, r, st.gcmx(v) * fc + min(k, st.gcmx(v)));
                                        k -= min(k, st.gcmx(v));
                                }
                        }
                }
        }
        st.pullall(l, r);
}

void magic(int i, int x) {
        i--;
        st.set(i, x);
}

ll inspect(int l, int r) {
        l--;
        return st.getsum(l, r);
}

#ifdef LOCAL
int main() {
    int N, Q;
    cin >> N >> Q;

    int h[N + 1];

    for (int i = 1; i <= N; ++i) cin >> h[i];

    initialise(N, Q, h);

    for (int i = 1; i <= Q; ++i) {
        int t;
        cin >> t;

        if (t == 1) {
            int l, r, k;
            cin >> l >> r >> k;
            cut(l, r, k);
        } else if (t == 2) {
            int i, x;
            cin >> i >> x;
            magic(i, x);
        } else {
            int l, r;
            cin >> l >> r;
            cout << inspect(l, r) << '\n';
        }
    }
    return 0;
}
#endif
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...