Submission #1128621

# | Time | Username | Problem | Language | Result | Execution time | Memory
1128621 | vladilius | Cake 3 (JOI19_cake3) | C++20
100 / 100
1138 ms | 17028 KiB
#include <algorithm>
#include <functional>
#include <iostream>
#include <utility>
#include <vector>

// JOI 2019 "Cake 3": choose exactly M of N cakes (value V_i, colour C_i) to
// maximise  sum(V) - 2 * (max C - min C).
//
// After sorting cakes by colour, fix the two colour-extreme cakes l < r; the
// best completion takes both plus the M - 2 largest values strictly between
// them. The optimal r is searched with divide-and-conquer optimisation, while
// a sliding window [tl, tr] of "between" cakes is kept in a segment tree
// indexed by value rank so the top M - 2 values can be summed quickly.

using ll = long long;
using pii = std::pair<int, int>;
using pli = std::pair<ll, int>;

// Sentinel for "no valid choice"; far below any achievable answer.
const ll infm = -1e18;

// Segment tree over value ranks 1..n. A leaf holds the value of the cake with
// that rank if the cake is currently inside the window, or 0 if it is absent.
struct ST {
    std::vector<pli> t;  // t[v] = {sum of present values, number of present cakes}
    int n, m;            // n leaves; m = number of cakes to choose

    ST(int ns, int ms) : t(4 * ns), n(ns), m(ms) {}

    // Sets the leaf at rank p to x; x > 0 marks the cake present, 0 absent.
    void upd(int v, int tl, int tr, int p, ll x) {
        if (tl == tr) {
            t[v] = {x, x > 0 ? 1 : 0};
            return;
        }
        const int tm = (tl + tr) / 2, vv = 2 * v;
        if (p <= tm) {
            upd(vv, tl, tm, p, x);
        } else {
            upd(vv + 1, tm + 1, tr, p, x);
        }
        t[v].first = t[vv].first + t[vv + 1].first;
        t[v].second = t[vv].second + t[vv + 1].second;
    }
    void upd(int p, ll x) { upd(1, 1, n, p, x); }

    // Sum of the k largest present values in v's subtree (all of them if
    // fewer than k are present). Non-positive k yields 0.
    ll get(int v, int tl, int tr, int k) const {
        if (k <= 0) return 0;
        if (t[v].second <= k) return t[v].first;
        const int tm = (tl + tr) / 2, vv = 2 * v;
        // The right child holds the larger ranks, so draw from it first.
        if (t[vv + 1].second >= k) return get(vv + 1, tm + 1, tr, k);
        return t[vv + 1].first + get(vv, tl, tm, k - t[vv + 1].second);
    }
    // Sum of the m - 2 largest values currently in the window.
    ll get() const { return get(1, 1, n, m - 2); }
};

// Returns the best achievable  sum(V) - 2 * (max C - min C)  over all choices
// of exactly M cakes, where cakes[i] = {V_i, C_i}. Returns infm when no valid
// choice exists (M < 1 or M > number of cakes).
ll max_cake_beauty(int M, const std::vector<pii>& cakes) {
    const int n = static_cast<int>(cakes.size());
    if (M < 1 || M > n) return infm;
    if (M == 1) {
        // A single cake has zero colour spread: take the most valuable one.
        ll best = infm;
        for (const auto& cake : cakes) best = std::max(best, static_cast<ll>(cake.first));
        return best;
    }

    // Order cakes by colour (ties by value) into 1-indexed arrays. 64-bit
    // storage keeps a[l] + a[r] - 2 * (b[r] - b[l]) free of int overflow.
    std::vector<pii> byColour;
    byColour.reserve(n);
    for (const auto& [v, c] : cakes) byColour.emplace_back(c, v);
    std::sort(byColour.begin(), byColour.end());
    std::vector<ll> a(n + 1), b(n + 1);  // a = value, b = colour
    for (int i = 1; i <= n; i++) {
        b[i] = byColour[i - 1].first;
        a[i] = byColour[i - 1].second;
    }

    // pos[i] = 1-based rank of cake i when ordered by (value, index); this is
    // its leaf in the segment tree.
    std::vector<pli> byValue;
    byValue.reserve(n);
    for (int i = 1; i <= n; i++) byValue.emplace_back(a[i], i);
    std::sort(byValue.begin(), byValue.end());
    std::vector<int> pos(n + 1);
    for (int i = 0; i < n; i++) pos[byValue[i].second] = i + 1;

    ST T(n, M);
    int tl = 1, tr = 0;  // window [tl, tr] currently loaded into T; starts empty

    // Adjusts the loaded window to [lo, hi], one cake at a time: extend the
    // right end, extend the left end, then shrink right, then shrink left.
    auto moveWindow = [&](int lo, int hi) {
        while (tr < hi) { tr++; T.upd(pos[tr], a[tr]); }
        while (tl > lo) { tl--; T.upd(pos[tl], a[tl]); }
        while (tr > hi) { T.upd(pos[tr], 0); tr--; }
        while (tl < lo) { T.upd(pos[tl], 0); tl++; }
    };

    // f(l, r): answer when cakes l and r are the colour extremes; assumes the
    // window currently equals [l + 1, r - 1].
    auto f = [&](int l, int r) -> ll {
        if (r - l + 1 < M) return infm;
        return a[l] + a[r] - 2 * (b[r] - b[l]) + T.get();
    };

    ll out = infm;
    // Left ends m in [l, r]; their optimal right ends are searched in [l1, r1].
    std::function<void(int, int, int, int)> solve = [&](int l, int r, int l1, int r1) {
        if (l > r) return;
        const int m = (l + r) / 2;
        pli opt = {infm, 0};
        for (int i = std::max(l1, m + M - 1); i <= r1; i++) {
            moveWindow(m + 1, i - 1);
            opt = std::max(opt, pli{f(m, i), i});
        }
        out = std::max(out, opt.first);
        solve(l, m - 1, l1, opt.second);
        solve(m + 1, r, opt.second, r1);
    };
    solve(1, n - M + 1, 1, n);
    return out;
}

int main() {
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);
    int n, M;
    if (!(std::cin >> n >> M) || n < 0) return 1;
    std::vector<pii> cakes(n);  // {V_i, C_i} as given in the input
    for (auto& [v, c] : cakes) std::cin >> v >> c;
    std::cout << max_cake_beauty(M, cakes) << "\n";
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...