Submission #1126341

# | Time | Username | Problem | Language | Result | Execution time | Memory
1126341 | | vladilius | Tricks of the Trade (CEOI23_trade) | C++20 | 10 / 100 | 2338 ms | 270704 KiB
// Tricks of the Trade (CEOI23_trade).
//
// Input : n k, then a[1..n], then b[1..n].
// Output: line 1 - the maximum over all segments [l, r] with at least k
//                  elements of (sum of the k largest b in the segment)
//                  minus (sum of a over the segment);
//         line 2 - a 0/1 string; position i is 1 when element i was among
//                  the k largest b (ties with the k-th value included) of
//                  some optimal segment visited by the final sweep.
//
// Approach: a persistent segment tree answers "sum of the k largest b in
// [l, r]"; a divide-and-conquer search over left endpoints finds the best
// right endpoint for each of them; a two-pointer sweep over the optimal
// left endpoints then marks the elements that can be sold.
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

using ll = long long;

// Value reported for a segment with fewer than k elements. Assumes every real
// profit is strictly greater than this (true for the contest limits).
constexpr ll kNegInf = -1000000000000000000LL;

/// Persistent segment tree indexed by value rank (1..n).
/// Version v contains the first v inserted elements; each node stores the sum
/// and the count of the elements whose rank lies in the node's range.
/// Nodes live in one vector and refer to each other by index, so nothing is
/// allocated per node and nothing leaks.
class PersistentRankTree {
public:
    explicit PersistentRankTree(int n) : n_(n) {
        int levels = 1;
        while ((1LL << (levels - 1)) < n_) ++levels;
        // The empty tree has at most 2n nodes and every insertion adds one
        // node per level; reserving up front avoids repeated reallocation.
        nodes_.reserve(static_cast<std::size_t>(n_) * (levels + 2) + 1);
        nodes_.push_back(Node{});  // index 0 is the "no child" node of leaves
        roots_.reserve(static_cast<std::size_t>(n_) + 1);
        roots_.push_back(build(1, n_));
    }

    /// Creates the next version by adding one element with the given rank and value.
    void insert(int rank, int value) {
        const int previous = roots_.back();
        const int created = insert(previous, 1, n_, rank, value);
        roots_.push_back(created);
    }

    /// Sum of the k largest-ranked elements among insertions l..r (1-based).
    /// The caller must ensure r - l + 1 >= k >= 1 and that ranks are distinct.
    ll top_k_sum(int l, int r, int k) const {
        int older = roots_[l - 1];
        int newer = roots_[r];
        int lo = 1;
        int hi = n_;
        ll total = 0;
        while (lo < hi) {
            const int mid = lo + (hi - lo) / 2;
            const Node& right_new = nodes_[nodes_[newer].right];
            const Node& right_old = nodes_[nodes_[older].right];
            const int in_right = right_new.count - right_old.count;
            if (k <= in_right) {
                // The k-th largest lies in the upper half; nothing is taken yet.
                older = nodes_[older].right;
                newer = nodes_[newer].right;
                lo = mid + 1;
            } else {
                // The whole upper half belongs to the top k.
                total += right_new.sum - right_old.sum;
                k -= in_right;
                older = nodes_[older].left;
                newer = nodes_[newer].left;
                hi = mid;
            }
        }
        // The leaf reached is the k-th largest element itself.
        return total + (nodes_[newer].sum - nodes_[older].sum);
    }

private:
    struct Node {
        int left = 0;
        int right = 0;
        ll sum = 0;
        int count = 0;
    };

    int make_node(int left, int right, ll sum, int count) {
        nodes_.push_back(Node{left, right, sum, count});
        return static_cast<int>(nodes_.size()) - 1;
    }

    int build(int lo, int hi) {
        if (lo == hi) return make_node(0, 0, 0, 0);
        const int mid = lo + (hi - lo) / 2;
        const int left = build(lo, mid);
        const int right = build(mid + 1, hi);
        return make_node(left, right, 0, 0);
    }

    // Only indices are kept across recursive calls: references into nodes_
    // would dangle if the vector ever had to grow.
    int insert(int node, int lo, int hi, int rank, int value) {
        if (lo == hi) {
            return make_node(0, 0, nodes_[node].sum + value, nodes_[node].count + 1);
        }
        const int mid = lo + (hi - lo) / 2;
        int left = nodes_[node].left;
        int right = nodes_[node].right;
        if (rank <= mid) {
            left = insert(left, lo, mid, rank, value);
        } else {
            right = insert(right, mid + 1, hi, rank, value);
        }
        return make_node(left, right,
                         nodes_[left].sum + nodes_[right].sum,
                         nodes_[left].count + nodes_[right].count);
    }

    int n_;
    std::vector<Node> nodes_;
    std::vector<int> roots_;
};

/// Segment tree over value ranks that tracks the elements of the current
/// window and remembers which ranks were ever part of a marked top-k set.
class TopKMarker {
public:
    /// by_value[1..n] must be the (value, index) pairs sorted ascending;
    /// by_value[0] is an unused placeholder.
    TopKMarker(int n, int k, const std::vector<std::pair<int, int>>& by_value)
        : n_(n),
          k_(k),
          values_(static_cast<std::size_t>(n) + 1, 0),
          marked_(static_cast<std::size_t>(n) + 1, 0),
          active_(static_cast<std::size_t>(4) * n, 0),
          pending_(static_cast<std::size_t>(4) * n, 0) {
        for (int i = 1; i <= n; ++i) values_[i] = by_value[i].first;
    }

    /// Puts the element with this rank into the window (present = true) or
    /// takes it out (present = false). Marks already made are kept.
    void set(int rank, bool present) { set(1, 1, n_, rank, present); }

    /// Marks every window element whose value is at least the k-th largest
    /// value in the window. The window must hold at least k elements.
    void mark_top_k() {
        const int kth = kth_largest(k_);
        // Step back to the first rank carrying the same value so that every
        // element tied with the k-th one is marked as well.
        const auto first_equal =
            std::lower_bound(values_.begin() + 1, values_.end(), values_[kth]);
        mark_from(1, 1, n_, static_cast<int>(first_equal - values_.begin()));
    }

    bool is_marked(int rank) const { return marked_[rank] != 0; }

private:
    void set(int node, int lo, int hi, int rank, bool present) {
        if (lo == hi) {
            active_[node] = present ? 1 : 0;
            pending_[node] = present ? 1 : 0;
            return;
        }
        const int mid = lo + (hi - lo) / 2;
        if (rank <= mid) {
            set(2 * node, lo, mid, rank, present);
        } else {
            set(2 * node + 1, mid + 1, hi, rank, present);
        }
        active_[node] = active_[2 * node] + active_[2 * node + 1];
        pending_[node] = (pending_[2 * node] || pending_[2 * node + 1]) ? 1 : 0;
    }

    int kth_largest(int k) const {
        int node = 1;
        int lo = 1;
        int hi = n_;
        while (lo < hi) {
            const int mid = lo + (hi - lo) / 2;
            const int in_right = active_[2 * node + 1];
            if (k <= in_right) {
                node = 2 * node + 1;
                lo = mid + 1;
            } else {
                k -= in_right;
                node = 2 * node;
                hi = mid;
            }
        }
        return lo;
    }

    // Marks every still-unmarked window element with rank >= from. Subtrees
    // without such an element are skipped, so each element is marked once per
    // insertion and the total work stays near-linear.
    void mark_from(int node, int lo, int hi, int from) {
        if (hi < from || !pending_[node]) return;
        if (lo == hi) {
            marked_[lo] = 1;
            pending_[node] = 0;
            return;
        }
        const int mid = lo + (hi - lo) / 2;
        mark_from(2 * node, lo, mid, from);
        mark_from(2 * node + 1, mid + 1, hi, from);
        pending_[node] = (pending_[2 * node] || pending_[2 * node + 1]) ? 1 : 0;
    }

    int n_;
    int k_;
    std::vector<int> values_;    // values_[rank] = value holding that rank
    std::vector<char> marked_;   // marked_[rank] = 1 once the element was marked
    std::vector<int> active_;    // number of window elements below each node
    std::vector<char> pending_;  // 1 when a window element below is unmarked
};

int main() {
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int n = 0;
    int k = 0;
    // Unreadable or out-of-range sizes cannot be answered; stop quietly
    // instead of recursing forever on an empty tree.
    if (!(std::cin >> n >> k) || n < 1 || k < 1 || k > n) return 0;

    std::vector<int> a(n + 1, 0);
    std::vector<ll> prefix(n + 1, 0);
    for (int i = 1; i <= n; ++i) {
        std::cin >> a[i];
        prefix[i] = prefix[i - 1] + a[i];
    }

    std::vector<int> b(n + 1, 0);
    std::vector<std::pair<int, int>> by_value;
    by_value.reserve(static_cast<std::size_t>(n) + 1);
    by_value.emplace_back(0, 0);  // placeholder so that ranks are 1-based
    for (int i = 1; i <= n; ++i) {
        std::cin >> b[i];
        by_value.emplace_back(b[i], i);
    }
    // The placeholder is excluded from the sort so it stays at index 0 even
    // if some b were negative.
    std::sort(by_value.begin() + 1, by_value.end());

    std::vector<int> rank_of(n + 1, 0);
    for (int i = 1; i <= n; ++i) rank_of[by_value[i].second] = i;

    PersistentRankTree tree(n);
    for (int i = 1; i <= n; ++i) tree.insert(rank_of[i], b[i]);

    // Profit of segment [l, r]; kNegInf when it has fewer than k elements.
    const auto profit = [&](int l, int r) -> ll {
        if (r - l + 1 < k) return kNegInf;
        return tree.top_k_sum(l, r, k) - (prefix[r] - prefix[l - 1]);
    };

    // best[l] = (best profit with left end l, smallest right end reaching it).
    std::vector<std::pair<ll, int>> best(n + 1);
    ll answer = kNegInf;

    // Divide and conquer over left ends [l, r] with right ends limited to
    // [lo, hi]; relies on the best right end being monotone in the left end.
    const auto solve = [&](auto&& self, int l, int r, int lo, int hi) -> void {
        if (l > r) return;
        const int m = l + (r - l) / 2;
        // Defaulting the split to hi matters when m has no feasible right
        // end: the left half must keep its full range. (Letting infeasible
        // candidates pull the split down to m made every shorter left end
        // look infeasible too, e.g. n = 10, k = 7.)
        best[m] = {kNegInf, hi};
        const int first = std::max(lo, m + k - 1);  // shortest feasible segment
        for (int i = hi; i >= first; --i) {
            const ll g = profit(m, i);
            // "<=" while scanning downwards keeps the smallest optimal right end.
            if (best[m].first <= g) best[m] = {g, i};
        }
        answer = std::max(answer, best[m].first);
        const int split = best[m].second;
        self(self, l, m - 1, lo, split);
        self(self, m + 1, r, split, hi);
    };
    solve(solve, 1, n, 1, n);

    std::vector<int> optimal_starts;
    for (int i = 1; i <= n; ++i) {
        if (best[i].first == answer) optimal_starts.push_back(i);
    }

    // Two-pointer sweep: for every optimal left end, walk the right end from
    // its first optimal position up to the first optimal position of the next
    // optimal left end, marking the top-k set whenever the segment is optimal.
    // The window held by `marks` is [left, right - 1].
    TopKMarker marks(n, k, by_value);
    int left = 1;
    int right = 1;
    for (std::size_t j = 0; j < optimal_starts.size(); ++j) {
        const int start = optimal_starts[j];
        while (right < best[start].second) marks.set(rank_of[right++], true);
        while (left < start) marks.set(rank_of[left++], false);
        if (start < right && profit(start, right - 1) == answer) marks.mark_top_k();

        const int last = (j + 1 == optimal_starts.size())
                             ? n
                             : best[optimal_starts[j + 1]].second;
        while (right <= last) {
            marks.set(rank_of[right], true);
            if (profit(start, right) == answer) marks.mark_top_k();
            ++right;
        }
    }

    std::string sellable(static_cast<std::size_t>(n), '0');
    for (int i = 1; i <= n; ++i) {
        if (marks.is_marked(rank_of[i])) sellable[i - 1] = '1';
    }
    std::cout << answer << "\n" << sellable << "\n";
    return 0;
}
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...