Submission #905116

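| # | Time | Username | Problem | Language | Result | Execution time | Memory |
|---|------|----------|---------|----------|--------|----------------|--------|
| 905116 | | GrindMachine | Cake 3 (JOI19_cake3) | C++17 | 100 / 100 | 1206 ms | 109780 KiB |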
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>

using namespace std;
using namespace __gnu_pbds;

template<typename T> using Tree = tree<T, null_type, less<T>, rb_tree_tag, tree_order_statistics_node_update>;
typedef long long int ll;
typedef long double ld;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;

#define fastio ios_base::sync_with_stdio(false); cin.tie(nullptr)
#define pb push_back
#define endl '\n'
#define sz(a) (int)a.size()
#define setbits(x) __builtin_popcountll(x)
#define ff first
#define ss second
#define conts continue
#define ceil2(x,y) ((x+y-1)/(y))
#define all(a) a.begin(), a.end()
#define rall(a) a.rbegin(), a.rend()
#define yes cout << "Yes" << endl
#define no cout << "No" << endl
#define rep(i,n) for(int i = 0; i < n; ++i)
#define rep1(i,n) for(int i = 1; i <= n; ++i)
#define rev(i,s,e) for(int i = s; i >= e; --i)
#define trav(i,a) for(auto &i : a)

template<typename T> void amin(T &a, T b) { a = min(a,b); }
template<typename T> void amax(T &a, T b) { a = max(a,b); }

#ifdef LOCAL
#include "debug.h"
#else
#define debug(x) 42
#endif

/*
refs:
http://rachitiitr.blogspot.com/2017/06/wavelet-trees-wavelet-trees-editorial.html (wavelet tree tutorial)

For a chosen set of m cakes: cost = (sum of values) - 2*(max_depth - min_depth).

Sort all cakes by depth. If the shallowest chosen cake is i and the deepest is
j, the depth term is fixed, so we only need to maximize the value sum:
take the m largest values among a[i..j].

Naive approach: fix i, sweep j >= i in increasing order and maintain the sum of
the m largest values with a priority_queue.

Speed-up: let opt(i) = the smallest j for which f(i, j) is maximal.
We rely on opt(i) <= opt(i+1). This is not proven here; the original author
checked it empirically with asserts on the first two subtasks.
Given that monotonicity, divide & conquer (like D&C DP optimization, with the
extra bounds optl/optr) needs only O(n log n) evaluations of f(i, j), and each
evaluation is answered by a wavelet tree over the compressed values.
*/

const int MOD = 1e9 + 7;
const int N = 2e5 + 5;
const int inf1 = int(1e9) + 5;
const ll inf2 = ll(1e18) + 5;

// Original (uncompressed) cake values, indexed by position in depth-sorted order.
// The wavelet tree sums these so that queries return real value sums.
vector<int> oa(N);

// Wavelet tree over compressed values [0, siz]. Each node splits its value
// range at mid; elements keep their original relative position order.
// Supports: k-th smallest value in a position range, and
// (count, sum of oa[]) of the elements whose value lies in a value range.
struct wavelet_tree{
    struct node{
        node *l = nullptr, *r = nullptr;
        vector<int> b;   // b[i] = how many of this node's first i elements go to the left child
        vector<ll> pref; // pref[i] = sum of oa[] over this node's first i elements
    };

    node* root = nullptr;
    int siz = 0;

    wavelet_tree(){ }

    // Builds the subtree for compressed values [l, r] over positions `inds`.
    // a[pos] is the compressed value at position pos. Nodes are allocated with
    // new and never freed; the tree is kept for the whole run.
    node* build(vector<int> &a, vector<int> inds, int l, int r){
        if(l > r) return nullptr;
        int mid = (l+r) >> 1;

        node* curr = new node();
        // Fill the node's arrays in place instead of building locals and copying them.
        auto &b = curr->b;
        auto &pref = curr->pref;
        b.assign(sz(inds)+1, 0);
        pref.assign(sz(inds)+1, 0);

        vector<int> la, ra;
        rep(i,sz(inds)){
            int pos = inds[i];
            bool goes_left = a[pos] <= mid;
            b[i+1] = b[i] + goes_left;
            pref[i+1] = pref[i] + oa[pos];
            if(goes_left) la.push_back(pos);
            else ra.push_back(pos);
        }
        // inds is no longer needed; release it so it is not held during recursion.
        vector<int>().swap(inds);

        if(l != r){
            curr->l = build(a, move(la), l, mid);
            curr->r = build(a, move(ra), mid+1, r);
        }
        return curr;
    }

    // Builds the tree over all of a[]; every a[i] must lie in [0, mx_siz].
    void build(vector<int> &a, int mx_siz){
        siz = mx_siz;
        vector<int> inds(sz(a));
        iota(all(inds), 0);
        root = build(a, move(inds), 0, siz);
    }

    // k-th smallest (1-indexed) compressed value among this node's positions
    // [l, r] (1-indexed); u covers the value range [lx, rx].
    int kth(node* u, int lx, int rx, int l, int r, int k){
        if(lx == rx) return lx;
        int mid = (lx+rx) >> 1;
        auto &b = u->b;
        int cnt = b[r]-b[l-1]; // elements of [l, r] that went left
        if(k <= cnt) return kth(u->l, lx, mid, b[l-1]+1, b[r], k);
        return kth(u->r, mid+1, rx, l-b[l-1], r-b[r], k-cnt);
    }

    // k-th smallest (1-indexed) compressed value among positions [l, r] (0-indexed).
    int kth(int l, int r, int k){
        return kth(root, 0, siz, l+1, r+1, k);
    }

    // (count, sum of oa[]) of this node's positions [l, r] (1-indexed) whose
    // compressed value lies in [vl, vr]; u covers the value range [lx, rx].
    pll get(node* u, int lx, int rx, int vl, int vr, int l, int r){
        if(lx > vr or rx < vl) return {0,0};
        if(!u) return {0,0};
        if(lx > rx) return {0,0};
        if(lx >= vl and rx <= vr){
            return {r-l+1, u->pref[r]-u->pref[l-1]};
        }
        int mid = (lx+rx) >> 1;
        auto &b = u->b;
        pll p1 = get(u->l, lx, mid, vl, vr, b[l-1]+1, b[r]);
        pll p2 = get(u->r, mid+1, rx, vl, vr, l-b[l-1], r-b[r]);
        return {p1.ff+p2.ff, p1.ss+p2.ss};
    }

    // (count, sum of oa[]) of positions [l, r] (0-indexed) whose compressed value is >= k.
    pll get(int l, int r, int k){
        return get(root, 0, siz, k, siz, l+1, r+1);
    }
};

vector<pii> a(N); // per cake: (depth, value); after compression .ss holds the compressed value
vector<int> c;    // sorted distinct original values (compressed index -> value)
wavelet_tree wt;
ll ans = -inf2;
int n,m;

// Best cost when the chosen cakes all lie in depth-sorted positions [l, r],
// with cake l the shallowest and cake r the deepest. The cost is the sum of
// the m largest values in [l, r] minus 2*(depth[r] - depth[l]).
// Returns -inf2 when [l, r] holds fewer than m cakes.
ll get_cost(int l, int r){
    int len = r-l+1;
    if(len < m) return -inf2;
    // The m-th largest value is the (len-m+1)-th smallest.
    int k = len-m+1;
    int kth = wt.kth(l,r,k);
    // Count and sum of every value >= that threshold. Ties at the threshold
    // can make cnt exceed m, so remove the surplus copies of the threshold value.
    auto [cnt,sum] = wt.get(l,r,kth);
    assert(cnt >= m);
    ll extra = cnt-m;
    sum -= extra*c[kth];
    // Widen before multiplying so the depth penalty is computed in 64-bit.
    sum -= 2*(ll(a[r].ff)-a[l].ff);
    return sum;
}

// Divide & conquer over the shallowest index i in [l, r], knowing that its
// optimal deepest index lies in [optl, optr]. Folds each best cost into ans.
void go(int l, int r, int optl, int optr){
    if(l > r) return;
    int mid = (l+r) >> 1;
    ll best = -inf2;
    int opt_mid = optr; // kept if no j in range gives a window of >= m cakes
    // j < mid is never a valid deepest index, so start the scan at mid.
    for(int j = max(optl, mid); j <= optr; ++j){
        ll cost = get_cost(mid,j);
        // Strict '>' keeps the smallest optimal j.
        if(cost > best){
            best = cost;
            opt_mid = j;
        }
    }
    amax(ans,best);
    go(l,mid-1,optl,opt_mid);
    go(mid+1,r,opt_mid,optr);
}

void solve(int test_case) {
    cin >> n >> m;
    // Each cake is read as (value, depth) and stored as (depth, value),
    // so that sorting orders the cakes by depth.
    rep(i,n) cin >> a[i].ss >> a[i].ff;
    sort(a.begin(),a.begin()+n);

    vector<int> b(n);
    rep(i,n) oa[i] = b[i] = a[i].ss;

    // Coordinate-compress the values.
    c = b;
    sort(all(c));
    c.resize(unique(all(c))-c.begin());
    int siz = sz(c);
    rep(i,n) b[i] = lower_bound(all(c),b[i])-c.begin();
    rep(i,n) a[i].ss = b[i];

    wt.build(b,siz-1);
    go(0,n-1,0,n-1);
    cout << ans << endl;
}

int main() {
    fastio;
    int t = 1;
    // cin >> t;
    rep1(i, t) {
        solve(i);
    }
    return 0;
}
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...