Submission #804352

#TimeUsernameProblemLanguageResultExecution timeMemory
804352ono_de206Distributing Candies (IOI21_candies)C++17
11 / 100
158 ms8896 KiB
#include "candies.h" #include<bits/stdc++.h> using namespace std; #define in insert #define all(x) x.begin(),x.end() #define pb push_back #define eb emplace_back #define ff first #define ss second // #define int long long typedef long long ll; typedef vector<int> vi; typedef set<int> si; typedef multiset<int> msi; typedef pair<int, int> pii; typedef vector<pii> vpii; const int mxn = 2e5 + 10; const long long inf = 1e18 + 10; struct segTreeBeats { struct node { long long mx1, mx2, mn1, mn2, lz, mxc, mnc, sum; node(int x = 0) : mx1(x), mx2(-inf), mn1(x), mn2(inf), lz(0), mxc(1), mnc(1), sum(x) {} friend node operator+(const node &a, const node &b) { node ret; ret.sum = a.sum + b.sum; ret.mx1 = max(a.mx1, b.mx1); ret.mn1 = min(a.mn1, b.mn1); if(b.mx1 == a.mx1) { ret.mxc = a.mxc + b.mxc; ret.mx2 = max(a.mx2, b.mx2); } else { if(b.mx1 > a.mx1) { ret.mxc = b.mxc; ret.mx2 = max(a.mx1, b.mx2); } else { ret.mxc = a.mxc; ret.mx2 = max(a.mx2, b.mx1); } } if(b.mn1 == a.mn1) { ret.mnc = a.mnc + b.mnc; ret.mn2 = min(a.mn2, b.mn2); } else { if(b.mn1 < a.mn1) { ret.mnc = b.mnc; ret.mn2 = min(a.mn1, b.mn2); } else { ret.mnc = a.mnc; ret.mn2 = min(a.mn2, b.mn1); } } return ret; } }; vector<node> d; vector<int> c; int n; void build(int l, int r, int i) { if(l == r) { d[i] = c[l]; return; } int m = (l + r) / 2; build(l, m, i * 2); build(m + 1, r, i * 2 + 1); d[i] = d[i * 2] + d[i * 2 + 1]; } segTreeBeats(int _n, vector<int> _c) { n = _n; c = _c; d.resize(n * 4 + 10); build(0, n - 1, 1); } void ADD(int i, int l, int r, long long v) { d[i].sum += 1LL * v * (r - l + 1); d[i].mx1 += v; d[i].mn1 += v; if(d[i].mx2 != -inf) d[i].mx2 += v; if(d[i].mn2 != inf) d[i].mn2 += v; d[i].lz += v; } void MNN(int i, int l, int r, long long v) { if(v >= d[i].mx1) return; d[i].sum -= d[i].mxc * d[i].mx1; d[i].mx1 = v; d[i].sum += d[i].mxc * d[i].mx1; if(l == r) d[i].mn1 = d[i].mx1; else { if(d[i].mn1 >= v) d[i].mn1 = v; else if(d[i].mn2 > v) d[i].mn2 = v; } } void MXX(int i, int l, int r, long long v) { if(v <= d[i].mn1) return; d[i].sum -= d[i].mnc * d[i].mn1; d[i].mn1 = v; d[i].sum += d[i].mnc * d[i].mn1; if(l == r) d[i].mx1 = d[i].mn1; else { if(d[i].mx1 <= v) d[i].mx1 = v; else if(d[i].mn2 < v) d[i].mx2 = v; } } void pro(int i, int l, int r) { if(l == r) return; int m = (l + r) / 2; ADD(i * 2, l, m, d[i].lz); ADD(i * 2 + 1, m + 1, r, d[i].lz); d[i].lz = 0; MNN(i * 2, l, m, d[i].mx1); MNN(i * 2 + 1, m + 1, r, d[i].mx1); MXX(i * 2, l, m, d[i].mn1); MXX(i * 2 + 1, m + 1, r, d[i].mn1); } void Add(int l, int r, int i, int x, int y, long long v) { if(l > y || r < x) return; if(l >= x && r <= y) { ADD(i, l, r, v); return; } pro(i, l, r); int m = (l + r) / 2; Add(l, m, i * 2, x, y, v); Add(m + 1, r, i * 2 + 1, x, y, v); d[i] = d[i * 2] + d[i * 2 + 1]; } void Mnn(int l, int r, int i, int x, int y, long long v) { if(l > y || r < x || d[i].mx1 <= v) return; if(l >= x && r <= y && d[i].mx2 < v) { MNN(i, l, r, v); return; } pro(i, l, r); int m = (l + r) / 2; Mnn(l, m, i * 2, x, y, v); Mnn(m + 1, r, i * 2 + 1, x, y, v); d[i] = d[i * 2] + d[i * 2 + 1]; } void Mxx(int l, int r, int i, int x, int y, long long v) { if(l > y || r < x || d[i].mn1 >= v) return; if(l >= x && r <= y && d[i].mn2 > v) { MXX(i, l, r, v); return; } pro(i, l, r); int m = (l + r) / 2; Mxx(l, m, i * 2, x, y, v); Mxx(m + 1, r, i * 2 + 1, x, y, v); d[i] = d[i * 2] + d[i * 2 + 1]; } void add(int l, int r, long long v) { Add(0, n - 1, 1, l, r, v); } void mnn(int l, int r, long long v) { Mnn(0, n - 1, 1, l, r, v); } void mxx(int l, int r, long long v) { Mxx(0, n - 1, 1, l, r, v); } long long GetSum(int l, int r, int i, int x, int y) { if(l >= x && r <= y) return d[i].sum; if(l > y || r < x) return 0LL; pro(i, l, r); int m = (l + r) / 2; return GetSum(l, m, i * 2, x, y) + GetSum(m + 1, r, i * 2 + 1, x, y); } long long getSum(int l, int r) { return GetSum(0, n - 1, 1, l, r); } }; vi distribute_candies(vi c, vi l, vi r, vi v) { int n = c.size(), q = l.size(); vector<long long> ret(n); if(n * q <= 2000 * 2000) { for(int i = 0; i < q; i++) { for(int j = l[i]; j <= r[i]; j++) { ret[j] += v[i]; ret[j] = min(1ll * c[j], max(0ll, ret[j])); } } return vector<int>(all(ret)); } if(*min_element(all(v)) >= 0) { for(int i = 0; i < q; i++) { ret[l[i]] += v[i]; if(r[i] + 1 < n) ret[r[i] + 1] -= v[i]; } for(int i = 1; i < n; i++) { ret[i] += ret[i - 1]; } for(int i = 0; i < n; i++) { ret[i] = min(1ll * c[i], max(0ll, ret[i])); } return vector<int>(all(ret)); } if(*max_element(all(c)) == *min_element(all(c))) { segTreeBeats st(n, c); for(int i = 0; i < q; i++) { st.add(l[i], r[i], v[i]); st.mnn(l[i], r[i], c[0]); st.mxx(l[i], r[i], 0); } for(int i = 0; i < n; i++) { ret[i] = st.getSum(i, i); } return vector<int>(all(ret)); } exit(1); return vector<int>(all(ret)); }
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...