/*
 * Submission #424096
 *
 * # | Time | Username | Problem | Language | Result | Execution time | Memory
 * 424096 | Mamnoon_Siam | Cake 3 (JOI19_cake3) | C++17 | 100 / 100 | 745 ms | 99616 KiB
 */
#include <bits/stdc++.h>
using namespace std;

using ll = long long;
using ii = pair<int, int>;
using vi = vector<int>;
#define all(v) begin(v), end(v)
#define sz(v) (int)(v).size()
#define fi first
#define se second
#ifdef LOCAL
#include "debug.h"
#else
#define debug(...)
#endif

// Capacity of the per-item arrays (root, V, C); the node pool is sized from it.
constexpr int N = 200'000 + 5;

// One vertex of the persistent segment tree stored in the pool `t`.
struct node {
  ll sum;  // total of the values inserted beneath this vertex
  int cnt; // number of values inserted beneath this vertex
  int l;   // pool index of the left child
  int r;   // pool index of the right child

  // Everything starts at zero. Child index 0 is significant: update() only
  // ever writes to freshly allocated slots (index >= 1), so slot 0 stays an
  // all-zero vertex and doubles as the empty subtree.
  node() : sum(0), cnt(0), l(0), r(0) {}
};
// Node pool of the persistent segment tree. Slot 0 is never written by
// update() and therefore acts as the shared empty subtree.
// NOTE(review): the factor 19 presumably bounds the nodes created per
// insertion (one per tree level) for n < N -- confirm.
node t[N * 19];
// Index of the most recently allocated pool slot; update() pre-increments it.
int ptr = 0;
// Number of items, read from input in main().
int n;
// Selection size, read from input in main(); cost() picks m-2 inner items.
int m;
// root[i] is the tree version after the first i items have been inserted.
int root[N];
// 1-based per-item data. V[i] is the value that gets summed; C[i] is the key
// the items are sorted by and whose spread cost() penalises.
ll V[N];
ll C[N];

void update(int& u, int b, int e, int pos, ll val) {
  t[++ptr] = t[u];
  t[u = ptr].sum += val;
  t[u].cnt++;
  if(b == e) return;
  int mid = (b + e) >> 1;
  if(pos <= mid)
    update(t[u].l, b, mid, pos, val);
  else
    update(t[u].r, mid+1, e, pos, val);
  // pull? no
}

// Sum of the values stored at the k leftmost occupied positions of the
// difference "version v minus version u" over the range [b, e].
// main() assigns positions by descending V, so leftmost means largest value.
// Precondition: the difference holds at least k values (t[v].cnt - t[u].cnt >= k).
ll largest_ksum(int u, int v, int b, int e, int k) { // [v] \ [u]
  ll total = 0;
  // Descend until the current subtree holds exactly the k values still needed.
  while (t[v].cnt - t[u].cnt != k) {
    const int left_u = t[u].l;
    const int left_v = t[v].l;
    const int in_left = t[left_v].cnt - t[left_u].cnt;
    const int half = (b + e) >> 1;
    if (in_left >= k) {
      // The left half alone can supply everything that is still needed.
      u = left_u;
      v = left_v;
      e = half;
    } else {
      // Take the whole left half, then fetch the remainder from the right.
      total += t[left_v].sum - t[left_u].sum;
      k -= in_left;
      u = t[u].r;
      v = t[v].r;
      b = half + 1;
    }
  }
  return total + (t[v].sum - t[u].sum);
}

// Score of using items l and r (indices in the C-sorted order) as the two
// endpoints: both endpoint values, plus the m-2 largest values strictly
// between them, minus twice the difference of the endpoint keys.
// The caller must ensure at least m-2 items lie strictly between l and r.
ll cost(int l, int r) {
  const ll inner = largest_ksum(root[l], root[r-1], 1, n, m-2);
  const ll endpoints = V[l] + V[r];
  const ll spread = 2*(C[r] - C[l]);
  return inner + endpoints - spread;
}

// Best score seen so far; starts at the smallest representable value so that
// the first candidate examined by solve() always replaces it.
ll ans = std::numeric_limits<ll>::min();

void solve(int l, int r, int b, int e) { // solve for [l, r] in search space [b, e]
  if(l > r) return;
  pair<ll,int> opt(LLONG_MIN, -1);
  int mid = (l + r) >> 1;
  for(int i = b; i <= min(e, mid-m+1); ++i) {
    opt = max(opt, {cost(i, mid), i});
  }
  assert(~opt.se);
  ans = max(ans, opt.fi);
  solve(l, mid-1, b, opt.se);
  solve(mid+1, r, opt.se, e);
}

int main(int argc, char const *argv[])
{
  cin.sync_with_stdio(0); cin.tie(0);
  cin.exceptions(cin.failbit);
#ifdef LOCAL
  freopen("in", "r", stdin);
#endif
  cin >> n >> m;
  for(int i = 1; i <= n; ++i) {
    cin >> V[i] >> C[i];
  }
  vi p(n); iota(all(p), 1);
  sort(all(p), [](int i,int j){ return C[i] < C[j]; });
  {
    vi VV(n), CC(n);
    for(int i = 0; i < n; ++i) {
      VV[i] = (int)V[p[i]];
      CC[i] = (int)C[p[i]];
    }
    for(int i = 1; i <= n; ++i) {
      V[i] = VV[i-1];
      C[i] = CC[i-1];
    }
  }
  debug(vi(C+1, C+1+n));
  debug(vi(V+1, V+1+n));
  vi v_ord(n); iota(all(v_ord), 1);
  sort(all(v_ord), [](int i, int j){ return V[i] > V[j]; });
  vi v_rank(n+1);
  for(int i = 0; i < n; ++i) {
    v_rank[v_ord[i]] = i+1;
  }
  debug(v_ord);
  debug(vi(v_rank.begin() + 1, v_rank.end()));
  root[0] = 0;
  for(int i = 1; i <= n; ++i) {
    root[i] = root[i-1];
    update(root[i], 1, n, v_rank[i], V[i]);
  }
  solve(m, n, 1, n);
  cout << ans << endl;
  return 0;
}
/*
* use std::array instead of std::vector, if u can
* overflow?
* array bounds
*/
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...