Submission #996845

# | Time | Username | Problem | Language | Result | Execution time | Memory
996845 |  | dimashhh | Distributing Candies (IOI21_candies) | C++17 | 67 / 100 | 1575 ms | 61592 KiB
#include <bits/stdc++.h>
#include "candies.h"
using namespace std;

using namespace std;
// Sizing constants. N sizes every global array and must exceed n + 1
// (solve() writes res[r + 1] with a 1-based r). B is the width of one
// sqrt-decomposition block.
constexpr int N = 300001;
constexpr int B = 448;
using ll = long long;

// a[]: per-box values. solve() stores 1-based capacities here in its small
// branches and 0-based box contents in its block-decomposition branch.
// C: the single capacity used by the block-decomposition branch.
ll a[N], C;

// Pending clamp function of one block:
//   x -> y1 for x <= x1, y2 for x >= x2, x - x1 + y1 in between.
// The member defaults read C at construction time. d[] is constructed during
// static initialisation while C is still 0, so every d[bl] starts as
// {0, 0, 0, 0}. That is harmless only because untouched blocks hold a[i] == 0.
struct func {
    ll x1{0};
    ll x2{C};
    ll y1{0};
    ll y2{C};
};
func d[N / B + 10];
// Evaluates the clamp function f at x: the constant y1 up to x1, the constant
// y2 from x2 on, and the slope-1 line through (x1, y1) strictly in between.
// When x1 >= x2 the left plateau wins.
ll get(func f,ll x){
    const bool insideRamp = x > f.x1 && x < f.x2;
    if(insideRamp){
        return f.y1 + (x - f.x1);
    }
    return x <= f.x1 ? f.y1 : f.y2;
}
// Index of the sqrt-decomposition block (B consecutive boxes) containing box i.
int get_bl(int i){
    const int blockIndex = i / B;
    return blockIndex;
}
void handle(int bl,int l,int r,int val){
    for(int i = bl * B;i < (bl + 1) * B;i++){
        a[i] = get(d[bl],a[i]);
        if(i >= l && i <= r){
            a[i] += val;
            a[i] = max(0ll,min(a[i],(ll)C));
        }
    }
    d[bl].x1 = d[bl].y1 = 0;
    d[bl].x2 = d[bl].y2 = C;
}
// Composes a block's pending clamp function t with "add val, then clamp to
// [0, C]" and returns the composed function.
// A func f maps x to y1 for x <= x1, to y2 for x >= x2, and to x - x1 + y1 in
// between (see get()), so only the two plateaus and breakpoints need updating.
func add1(func t,int val){
    if(val > 0){
        // The highest output still fits under C: shift both plateaus up.
        if(t.y2 + val <= C){
            t.y2 += val;
            t.y1 += val;
            return t;
        }
        // Even the lowest output reaches C: the result is the constant C.
        // With x1 = C and x2 = C + 1, get() returns C for every x in [0, C].
        if(t.y1 + val >= C){
            t.y2 = t.y1 = t.x1 = C;
            t.x2 = C + 1;
            return t;
        }
        // Only the top is cut: move x2 left by the overflow (y2 + val - C).
        // x2 must be adjusted before y2 is overwritten below.
        t.x2 -= (t.y2 - C + val);
        t.y1 += val;
        t.y2 = C;
        return t;
    }else{
        // The lowest output stays non-negative: shift both plateaus down.
        if(t.y1 + val >= 0){
            t.y1 += val;
            t.y2 += val;
            return  t;
        }
        // Even the highest output drops to <= 0: the result is the constant 0.
        if(t.y2 + val <= 0){
            t.x1 = C;
            t.y1 = 0;
            t.x2 = C + 1;
            t.y2 = 0;
            return t;
        }
        // Only the bottom is cut: move x1 right by the underflow -(y1 + val).
        // x1 must be adjusted before y1 is overwritten below.
        t.y2 += val;
        t.x1 += (-t.y1 - val);
        t.y1 = 0;
        return t;
    }
}
// Difference array (then prefix sums) for solve()'s "no negative update" branch.
ll res[N];
// Number of boxes and number of updates of the current call. q is also the
// default right end of the time segment tree defined further below.
int n,q;

// Sub-solver for restricted inputs. Strategies, tried in this order:
//  1. no v[j] is negative: box i ends with min(c[i], total added), computed
//     with a difference array over 1-based boxes (O(n + q));
//  2. n <= 2000: per-box simulation over all updates (O(n * q));
//  3. otherwise: sqrt decomposition over blocks of B boxes using the single
//     global capacity C = c[0].
// NOTE(review): strategy 3 never looks at c[1..n-1]. It is correct only when
// every capacity equals c[0], and the caller must guarantee that.
// l and r are taken by value and shifted to 1-based in strategies 1 and 2.
// Relies on global state (a, res, d) being fresh, i.e. a single call.
vector<int> solve(vector<int> c, vector<int> l,vector<int> r, vector<int> v){
    bool ok = true;
    n = (int)c.size();
    q = (int)l.size();
    for(auto j:v){
        if(j < 0)ok = false;
    }
    if(n <= 2000 || ok){
        // a[1..n] = capacities (1-based) for strategies 1 and 2.
        for(int i = 1;i <= n;i++){
            a[i] = c[i - 1];   
        }
        if(ok){
            // Strategy 1: with only additions, the box just saturates at c[i].
            for(int i = 0;i < q;i++){
                l[i]++;
                r[i]++;
                res[l[i]] += v[i];
                res[r[i] + 1] -= v[i];
            }
            vector<int> ret;
            for(int i = 1;i <= n;i++){
                res[i] += res[i - 1];
                ret.push_back((int)min((ll)a[i],res[i]));
            }
            return ret;
        }
        // Strategy 2: brute-force simulation per box.
        for(int i = 0;i < q;i++){
            l[i]++;
            r[i]++;
        }
        vector<int> ret;
        for(int i = 1;i <= n;i++){
            // [d, u] is a window of prefix-sum levels with u - d == a[i];
            // the box holds cur - d. When cur leaves the window the box is
            // clamped, so the window slides to follow cur.
            ll u = a[i],d = 0,cur = 0;
            for(int j = 0;j < q;j++){
                if(l[j] > i || r[j] < i) continue;
                cur += v[j];
                if(cur >= u){
                    u = cur;
                    d = u - a[i];
                }
                if(cur <= d){
                    d = cur;
                    u = d + a[i];
                }
            }
            ret.push_back(cur - d);
        }
        return ret;
    }
    // Strategy 3: a[] holds 0-based box contents here and is assumed to be all
    // zero on entry (its static initial value). d[bl] is the pending clamp
    // function of block bl.
    vector<int> ret;
    C = c[0];
    ret.resize(n);
    for(int i = 0;i < q;i++){
        // The partial blocks at both ends are rebuilt slot by slot. Fully
        // covered inner blocks only compose their pending function.
        int x = get_bl(l[i]),y = get_bl(r[i]);
        handle(x,l[i],r[i],v[i]);
        if(x != y) handle(y,l[i],r[i],v[i]);
        for(int j = x + 1;j < y;j++){
            d[j] = add1(d[j],v[i]);
        }
    }
    for(int i = 0;i < (int)c.size();i++){
        ret[i] = (int)get(d[get_bl(i)],a[i]);
    }
    return ret;
}
// Sweep events over boxes: add[x] / del[x] hold the update indices whose
// range starts / ends at box x.
vector<int> add[N];
vector<int> del[N];
// Lazy segment tree over times 0..q.
//   t[v]   : node value. inc() adds a shift once per node (not per element),
//            so it is only meaningful at leaves.
//   mod[v] : pending shift not yet pushed to the children.
//   mx/mn  : (value, time index) extrema of the node's range.
ll t[N * 4],mod[N * 4];
pair<ll,int> mx[N * 4],mn[N * 4];

// Initialises mx/mn for node v covering [tl, tr]: every value is 0, and an
// inner node starts with its right child's entry (the rightmost time index).
void build(int v = 1,int tl = 0,int tr = q){
    if(tl != tr){
        const int tm = (tl + tr) / 2;
        build(2 * v, tl, tm);
        build(2 * v + 1, tm + 1, tr);
        mn[v] = mx[2 * v + 1];
        mx[v] = mn[v];
        return;
    }
    mn[v] = {0, tl};
    mx[v] = mn[v];
}
// Recomputes node v from its two children. Ties in the maximum and (as in
// get_min) ties in the minimum both resolve to the larger time index.
void merge(int v){
    const int lc = v + v;
    const int rc = lc + 1;
    t[v] = t[lc] + t[rc];
    mx[v] = mx[lc] < mx[rc] ? mx[rc] : mx[lc];
    const auto& lowLeft = mn[lc];
    const auto& lowRight = mn[rc];
    if(lowLeft.first == lowRight.first){
        mn[v] = lowLeft < lowRight ? lowRight : lowLeft;
    }else{
        mn[v] = lowRight < lowLeft ? lowRight : lowLeft;
    }
}
// Applies a pending "+val" to the whole range of node v: shifts its value and
// both extrema, and records the shift in mod[v] for push() to forward later.
// val is long long because push() forwards the accumulated mod[v]. After many
// overlapping updates (each up to 1e9 in magnitude) that sum leaves the int
// range. The former int parameter silently truncated it and corrupted every
// value in the subtree.
void inc(int v,long long val){
    t[v] += val;
    mn[v].first += val;
    mx[v].first += val;
    mod[v] += val;
}
// Forwards node v's pending shift to both children and clears it.
void push(int v){
    if(!mod[v]) return;
    const auto pending = mod[v];
    inc(2 * v, pending);
    inc(2 * v + 1, pending);
    mod[v] = 0;
}
// Adds val to every time point in [l, r] (no-op when l > r).
void upd(int l,int r,int val,int v = 1,int tl = 0,int tr = q){
    const bool disjoint = l > r || r < tl || tr < l;
    if(disjoint) return;
    if(l <= tl && tr <= r){
        inc(v, val);
        return;
    }
    push(v);
    const int tm = (tl + tr) / 2;
    upd(l, r, val, 2 * v, tl, tm);
    upd(l, r, val, 2 * v + 1, tm + 1, tr);
    merge(v);
}
const ll inf = 1e18;
pair<ll,ll> get_max(int l,int r,int v = 1,int tl =0 ,int tr = q){
    if(l > r || tl > r || l > tr) return {-inf,-inf};
    if(tl >= l && tr <= r) return mx[v];
    push(v);
    int tm = (tl + tr) >> 1;
    return max(get_max(l,r,v+v,tl,tm),get_max(l,r,v+v+1,tm+1,tr));
}
pair<ll,ll> get_min(int l,int r,int v = 1,int tl =0 ,int tr = q){
    if(l > r || tl > r || l > tr) return {inf,inf};
    if(tl >= l && tr <= r) return mn[v];
    push(v);
    int tm = (tl + tr) >> 1;
    pair<ll,ll> L = get_min(l,r,v+v,tl,tm),R = get_min(l,r,v+v+1,tm+1,tr);
    if(L.first != R.first){
        return min(L,R);
    }
    return max(L,R);
}
// Sum of node values over [l, r]. Callers only query single points (l == r),
// where it returns that leaf's value. Returns 0 for an empty range.
ll get_sum(int l,int r,int v = 1,int tl = 0,int tr = q){
    if(l > r || r < tl || tr < l) return 0;
    if(l <= tl && tr <= r) return t[v];
    push(v);
    const int tm = (tl + tr) / 2;
    const ll fromLeft = get_sum(l, r, 2 * v, tl, tm);
    return fromLeft + get_sum(l, r, 2 * v + 1, tm + 1, tr);
}

// True when the values stored at times i..q spread (max - min) by more than c.
// Shrinking the suffix can only shrink the spread, so this is monotone in i.
bool ok(int i,int c){
    const auto highest = get_max(i, q).first;
    const auto lowest = get_min(i, q).first;
    return highest - lowest > c;
}
vector<int> distribute_candies(vector<int> c, vector<int> l,vector<int> r, vector<int> v){
    n = (int)c.size();
    q = (int)l.size();
    bool OK = false;
    bool bad = true;
    for(int i = 0;i < q;i++){
        if(l[i] != 0 || r[i] != n -1)bad = false;
    }
    for(int i = 1;i < (int)c.size();i++){
        if(c[i] != c[i - 1]){
            OK =1;
        }
    }
    if(!OK && !bad){
        return solve(c,l,r,v);
    }
    for(int i:c){
        if(i < 0) OK = false;
    }
    if(OK && !bad){
        return solve(c,l,r,v);
    }
    build();
    for(int i = 0;i < q;i++){
        add[l[i]].push_back(i);
        del[r[i]].push_back(i);
    }
    vector<int> ret;
    for(int i = 0;i < n;i++){
        for(int j:add[i]){
            upd(j + 1,q,v[j]);
        }
        ll all = get_sum(q,q);
        if(mx[1].first - mn[1].first <= c[i]){
            ret.push_back(get_sum(q,q)-mn[1].first);
        }else{
            int L = 0,R = q+1;
            while(R - L > 1){
                int mid = (L + R) >> 1;
                if(ok(mid,c[i])){
                    L = mid;
                }else R = mid;
            }
            ll val = get_sum(L,L);
            pair<ll,ll> f = get_max(L,q),s = get_min(L,q);
            if(abs(f.first - val) > c[i]){
                ret.push_back(c[i] - (f.first - all));
            }else{
                ret.push_back(all - s.first);
            }
        }
        for(int j:del[i]){
            upd(j + 1,q,-v[j]);
        }
    }
    return ret;
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...