Submission #1197150

# | Time | Username | Problem | Language | Result | Execution time | Memory
1197150 | avighna | Cake 3 (JOI19_cake3) | C++20
100 / 100
682 ms | 106388 KiB
#include <bits/stdc++.h>

/// Static wavelet tree over a sequence of integer keys in [low, high].
///
/// Every node covers a key interval [low, high] and stores, for the
/// subsequence of elements that reached it (in their original relative order):
///   * map_left[i] - how many of the first i elements go to the left child
///                   (key <= mid),
///   * psum[i]     - sum of decompressed[key] over the first i elements.
///
/// Queries use 0-based inclusive positions [l, r] into the sequence as it was
/// ordered at construction time.
///
/// Children are owned through std::unique_ptr, so the tree is non-copyable
/// (a copy would otherwise have shared and double-freed the child nodes).
/// The tree keeps a reference to `decompressed`; the caller must keep that
/// vector alive for as long as the tree exists.
class WaveletTree {
private:
  int low, high;
  // A child is null when no element of this node falls into its half.
  std::unique_ptr<WaveletTree> left, right;
  std::vector<int> map_left;
  std::vector<long long> psum;
  std::vector<int> &decompressed;

public:
  /// Builds the node for keys in [low, high] from the range [start, end).
  ///
  /// NOTE: the range is stable-partitioned in place at every level, so the
  /// caller's sequence is reordered by construction.
  ///
  /// @param start,end     range of keys; each key must be a valid index into
  ///                      `decompressed` and lie inside [low, high]
  /// @param low,high      inclusive key interval covered by this node
  /// @param decompressed  weight looked up for each key when building psum
  WaveletTree(std::vector<int>::iterator start, std::vector<int>::iterator end,
              int low, int high, std::vector<int> &decompressed)
      : low(low), high(high), decompressed(decompressed) {
    if (start >= end) {
      return;
    }

    const int mid = low + (high - low) / 2;
    map_left.reserve(end - start + 1);
    psum.reserve(end - start + 1);

    map_left.push_back(0);
    psum.push_back(0);

    for (auto it = start; it != end; ++it) {
      map_left.push_back(map_left.back() + (*it <= mid));
      psum.push_back(psum.back() + decompressed[*it]);
    }

    // Leaves keep their prefix tables (sum() reads psum) but have no children.
    if (low == high) {
      return;
    }

    auto pivot =
        std::stable_partition(start, end, [mid](int x) { return x <= mid; });

    if (start != pivot) {
      left = std::make_unique<WaveletTree>(start, pivot, low, mid,
                                           decompressed);
    }
    if (pivot != end) {
      right = std::make_unique<WaveletTree>(pivot, end, mid + 1, high,
                                            decompressed);
    }
  }

  /// Returns the k-th smallest key (1-based k) among positions [l, r].
  ///
  /// @return the key, or -1 when the range is empty (l > r) or k is outside
  ///         [1, r - l + 1].
  int kth(int l, int r, int k) const {
    if (l > r || k < 1 || k > r - l + 1)
      return -1;
    if (low == high)
      return low;

    const int left_l = map_left[l], left_r = map_left[r + 1];
    const int left_count = left_r - left_l;

    if (k <= left_count) {
      // left_count >= 1 implies the left child exists; the guard is defensive.
      return left ? left->kth(left_l, left_r - 1, k) : -1;
    }
    return right ? right->kth(l - left_l, r - left_r, k - left_count) : -1;
  }

  /// Aggregates the elements at positions [l, r] whose key is <= k.
  ///
  /// @return {sum of decompressed[key], number of such elements};
  ///         {0, 0} when the range is empty or no key qualifies.
  std::pair<long long, int> sum(int l, int r, int k) const {
    if (l > r || k < low) {
      return {0, 0};
    }

    // If all elements in this range are <= k, return the sum of this range
    if (high <= k) {
      return {psum[r + 1] - psum[l], r - l + 1};
    }

    // Find the partition of the current range into left and right subtrees
    const int left_l = map_left[l], left_r = map_left[r + 1];
    std::pair<long long, int> total{0, 0};
    // A missing child holds no elements; skip it instead of calling a member
    // function through a null pointer.
    if (left) {
      const auto [s, c] = left->sum(left_l, left_r - 1, k);
      total.first += s;
      total.second += c;
    }
    if (right) {
      const auto [s, c] = right->sum(l - left_l, r - left_r, k);
      total.first += s;
      total.second += c;
    }
    return total;
  }
};

/// Reads `n m` followed by n pairs `v c` from stdin and prints one integer.
///
/// After sorting the items by c, the program considers every pair of indices
/// tm < i with at least m - 2 items strictly between them and evaluates
///   v[tm] + v[i] + (sum of the m - 2 largest v strictly between)
///     - 2 * (c[i] - c[tm]),
/// printing the maximum. The "m - 2 largest" part is answered by a wavelet
/// tree built over the ranks of -v. The search over (tm, i) uses
/// divide-and-conquer, which assumes the best i is non-decreasing in tm
/// (NOTE(review): this monotonicity is relied upon, not checked here).
///
/// NOTE(review): the code assumes 3 <= m <= n; smaller inputs are not guarded.
int main() {
  std::ios_base::sync_with_stdio(false);
  std::cin.tie(nullptr);

  const int64_t inf = 1e15;

  int n, m;
  std::cin >> n >> m;
  // Stored as (2 * c, v) so that sorting orders by c and the difference of
  // .first values directly yields 2 * (c[i] - c[tm]).
  std::vector<std::pair<int64_t, int64_t>> a(n);
  for (auto &[c, v] : a) {
    std::cin >> v >> c;
    c <<= 1;
  }
  std::sort(a.begin(), a.end());

  int64_t ans = -inf;
  // Keys are negated values so that a smaller rank means a larger v.
  std::vector<std::pair<int, int>> compress(n);
  for (int i = 0; i < n; ++i) {
    compress[i] = {-a[i].second, i};
  }
  std::sort(compress.begin(), compress.end());
  // coordinate compress b: b[i] is the rank of -v[i], decompressed[rank] = -v
  std::vector<int> b(n), decompressed(n);
  int j = 0;
  for (int i = 0; i < n; ++i) {
    if (i != 0 && compress[i].first != compress[i - 1].first) {
      decompressed[j] = compress[i - 1].first;
      j++;
    }
    b[compress[i].second] = j;
  }
  decompressed[j] = compress.back().first;
  // The constructor reorders b in place; b is not used afterwards.
  WaveletTree wv(b.begin(), b.end(), *std::min_element(b.begin(), b.end()),
                 *std::max_element(b.begin(), b.end()), decompressed);

  // solve(tl, tr, l, r): handle left endpoints tm in [tl, tr], trying right
  // endpoints i in [l, r] only.
  auto solve = [&](auto &&self, int tl, int tr, int l, int r) -> void {
    if (tl > tr) {
      return;
    }
    int tm = (tl + tr) / 2;
    std::pair<int64_t, int> opt = {-inf, -1};
    for (int i = l; i <= r; ++i) {
      // Need at least m - 2 items strictly between tm and i.
      if (i - 1 - (tm + 1) + 1 < m - 2) {
        continue;
      }
      // Rank of the (m - 2)-th largest v in the open interval (tm, i).
      int val = wv.kth(tm + 1, i - 1, m - 2);
      // Everything strictly better than that rank is taken in full ...
      auto [sum, cnt] = wv.sum(tm + 1, i - 1, val - 1);
      // ... and the remaining slots are filled with copies of the k-th value.
      // Widen before multiplying: count * value can exceed the int range.
      sum += static_cast<int64_t>(m - 2 - cnt) * decompressed[val];
      sum = -sum; // weights are stored negated
      opt = std::max(
          opt,
          {a[tm].second + sum + a[i].second - (a[i].first - a[tm].first), i});
    }
    ans = std::max(ans, opt.first);
    self(self, tl, tm - 1, l, opt.second);
    self(self, tm + 1, tr, opt.second, r);
  };
  solve(solve, 0, n - m, 0, n - 1);
  std::cout << ans << '\n';
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...