// Submission #966976
//
// # | Time | Username | Problem | Language | Result | Execution time | Memory
// 966976 | - | Soumya1 | Cake 3 (JOI19_cake3) | C++17 | 100 / 100 | 692 ms | 19020 KiB
#include <bits/stdc++.h>
using namespace std;
// 64-bit signed integer used for value sums and the final answer.
using ll = long long;
// Capacity of the static arrays below (1-based indexing, a few spare slots).
const int mxN = 2e5 + 5;
// Sentinel magnitude: -inf marks "no valid selection".
const ll inf = 1e18;
// Per-item data, 1-based: v[i] and c[i] are the two numbers read for item i
// (testCase reorders items by ascending c); comp[i] is the rank of v[i] among
// all values in descending order (1 = largest), unique per item.
int v[mxN], c[mxN], comp[mxN];
// Best objective value found so far by solve().
ll ans = -inf;
// n = number of items, m = how many items are summed by query();
// [lptr, rptr] is the index window whose items are currently stored in the
// Fenwick trees (both start at 0, an index that add()/rem() ignore).
int n, m, lptr, rptr;
// Fenwick tree over ranks holding the number of stored items.
int bx[mxN];
// Fenwick tree over ranks holding the sum of the stored items' values.
ll by[mxN];
// Point update on the paired Fenwick trees: adds x to the count tree (bx)
// and y to the value-sum tree (by) at 1-based position i, then propagates
// the change to every node whose range covers i.
void upd(int i, int x, int y) {
  while (i <= n) {
    bx[i] += x;
    by[i] += y;
    i += i & (-i);  // jump to the next node responsible for position i
  }
}
pair<int, ll> query(int i) {
  int sx = 0;
  ll sy = 0;
  for (; i > 0; i -= i & (-i)) sx += bx[i], sy += by[i];
  return {sx, sy};
}
// Returns the sum of the values stored at the m smallest occupied positions
// of the Fenwick trees (i.e. the m best-ranked items currently in the
// window), or -inf when fewer than m items are stored.
//
// Works by binary descent over the count tree: after the loop, `cur` is the
// largest position whose prefix count is still below m, so position cur + 1
// is where the prefix count reaches m.
ll query() {
  // Largest power of two not exceeding n (1 when n < 2); derived from n so
  // the descent stays correct for any array capacity.
  int step = 1;
  while (step <= n / 2) step *= 2;

  int cur = 0, sumx = 0;
  for (; step > 0; step >>= 1) {
    const int next = cur + step;
    if (next <= n && sumx + bx[next] < m) {
      sumx += bx[next];
      cur = next;
    }
  }
  // The whole tree holds fewer than m items: do not touch position n + 1.
  if (cur >= n) return -inf;
  auto [fx, fy] = query(cur + 1);
  if (fx != m) return -inf;
  return fy;
}
// Inserts item i into the window structure: one more element at its rank
// comp[i], contributing value v[i]. Index 0 is not an item and is ignored.
void add(int i) {
  if (i != 0) {
    upd(comp[i], 1, v[i]);
  }
}
// Removes item i from the window structure: one element fewer at its rank
// comp[i], and v[i] subtracted from the sum. Index 0 is not an item and is
// ignored.
void rem(int i) {
  if (i != 0) {
    upd(comp[i], -1, -v[i]);
  }
}
// Moves the maintained window [lptr, rptr] to [l, r] one index at a time and
// returns query() for the resulting contents.
// The window is grown on both sides first and only then shrunk, so it never
// becomes inverted while the pointers move.
ll get(int l, int r) {
  // Grow to the left.
  while (l < lptr) {
    --lptr;
    add(lptr);
  }
  // Grow to the right.
  while (rptr < r) {
    ++rptr;
    add(rptr);
  }
  // Shrink from the right.
  while (r < rptr) {
    rem(rptr);
    --rptr;
  }
  // Shrink from the left.
  while (lptr < l) {
    rem(lptr);
    ++lptr;
  }
  return query();
}
void solve(int l, int r, int optl, int optr) {
  if (l > r) return;
  int m = (l + r) >> 1;
  int opt;
  ll best = -inf;
  for (int i = max(m, optl); i <= optr; i++) {
    if (best < get(m, i) + 2 * (c[m] - c[i])) {
      best = get(m, i) + 2 * (c[m] - c[i]);
      opt = i;
    }
  }
  ans = max(ans, best);
  if (best == -inf) {
    solve(l, m - 1, optl, optr);
  } else {
    solve(l, m - 1, optl, opt);
    solve(m + 1, r, opt, optr);
  }
}
// Reads one instance from stdin, prepares the global arrays, runs solve()
// and prints the answer followed by a newline.
//
// Input: n and m, then n pairs (v[i], c[i]).
// Preparation:
//   1. items are reordered by ascending c;
//   2. comp[i] becomes the 1-based rank of v[i] in descending value order,
//      with equal values ranked in index order so every rank is unique.
void testCase() {
  cin >> n >> m;
  for (int i = 1; i <= n; i++) {
    cin >> v[i] >> c[i];
  }
  {
    // Reorder the items so that c is non-decreasing in the index.
    vector<int> ord(n);
    iota(ord.begin(), ord.end(), 1);
    sort(ord.begin(), ord.end(), [&](int i, int j) { return c[i] < c[j]; });
    vector<int> nv(n + 1), nc(n + 1);
    for (int i = 0; i < n; i++) nv[i + 1] = v[ord[i]], nc[i + 1] = c[ord[i]];
    for (int i = 1; i <= n; i++) c[i] = nc[i], v[i] = nv[i];
  }
  {
    // Rank compression: a stable sort by descending value keeps equal values
    // in index order, which gives each item a distinct rank in 1..n.
    vector<int> byValue(n);
    iota(byValue.begin(), byValue.end(), 1);
    stable_sort(byValue.begin(), byValue.end(),
                [&](int i, int j) { return v[i] > v[j]; });
    for (int k = 0; k < n; k++) comp[byValue[k]] = k + 1;
  }
  solve(1, n, 1, n);
  cout << ans << "\n";
}
// Program entry point: configures fast stream I/O and solves the single
// test case read from standard input.
int main() {
  // Neither call performs I/O, so their relative order does not matter.
  cin.tie(nullptr);
  ios_base::sync_with_stdio(false);
  testCase();
}
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...