#include <bits/stdc++.h>
/// Wavelet tree over a sequence of ints whose values lie in [low, high].
///
/// Each node covers a value interval [low, high] and keeps, for the
/// subsequence of elements falling in that interval (in original order):
///   - map_left[i]: how many of the first i elements are <= mid and therefore
///     belong to the left child;
///   - psum[i]: the sum of the first i elements.
/// Positions passed to the queries are 0-based and inclusive.
///
/// Children are owned through std::unique_ptr, so the tree is move-only and
/// cleans itself up without a hand-written destructor.
class WaveletTree {
private:
  int low, high;
  std::unique_ptr<WaveletTree> left, right;
  std::vector<int> map_left;
  std::vector<long long> psum;

  /// True when [l, r] is a non-empty range that lies inside this node.
  bool valid_range(int l, int r) const {
    return l >= 0 && l <= r &&
           static_cast<std::size_t>(r) + 1 < map_left.size();
  }

public:
  /// Builds the tree for the elements in [start, end).
  ///
  /// @param start,end  Range to index. NOTE: the range is reordered in place
  ///                   (stable partition by value at every level).
  /// @param low,high   Inclusive bounds of the values in the range; the
  ///                   caller must make sure every element lies inside them.
  WaveletTree(std::vector<int>::iterator start, std::vector<int>::iterator end,
              int low, int high)
      : low(low), high(high) {
    if (start >= end) {
      return;
    }
    // Computed in 64-bit so that a very wide [low, high] cannot overflow int.
    const int mid =
        static_cast<int>(low + (static_cast<long long>(high) - low) / 2);
    map_left.reserve(end - start + 1);
    psum.reserve(end - start + 1);
    map_left.push_back(0);
    psum.push_back(0);
    // The prefix tables describe the elements in their order *before* the
    // partition below, i.e. the order seen by this node's callers.
    for (auto it = start; it != end; ++it) {
      map_left.push_back(map_left.back() + (*it <= mid));
      psum.push_back(psum.back() + *it);
    }
    if (low == high) {
      return;
    }
    const auto pivot =
        std::stable_partition(start, end, [mid](int x) { return x <= mid; });
    if (start != pivot) {
      left = std::make_unique<WaveletTree>(start, pivot, low, mid);
    }
    if (pivot != end) {
      right = std::make_unique<WaveletTree>(pivot, end, mid + 1, high);
    }
  }

  /// Returns the k-th smallest value (1-based) among positions [l, r].
  ///
  /// Returns -1 when the range is empty or out of bounds, or when k is not in
  /// [1, r - l + 1]. The sentinel is indistinguishable from a stored value of
  /// -1, so callers holding such values must validate their arguments first.
  int kth(int l, int r, int k) const {
    if (!valid_range(l, r) || k < 1 || k > r - l + 1) {
      return -1;
    }
    if (low == high) {
      return low;
    }
    const int left_l = map_left[l], left_r = map_left[r + 1];
    const int left_count = left_r - left_l;
    // With 1 <= k <= r - l + 1 the chosen child holds at least one element
    // of the range, so it is guaranteed to exist.
    if (k <= left_count) {
      return left->kth(left_l, left_r - 1, k);
    }
    return right->kth(l - left_l, r - left_r, k - left_count);
  }

  /// Returns the sum of the `cnt` smallest elements among positions [l, r],
  /// considering only elements with value <= k. If fewer than `cnt` such
  /// elements exist, all of them are summed. Returns 0 for an empty or
  /// out-of-bounds range or when cnt <= 0.
  long long sum(int l, int r, int k, int cnt) const {
    if (!valid_range(l, r) || k < low || cnt <= 0) {
      return 0;
    }
    const int len = r - l + 1;
    // The prefix-sum shortcut is only correct when it does not matter which
    // elements are picked: either all of them are taken, or (leaf) they are
    // all equal. Otherwise psum would pick the first `cnt` elements in
    // positional order rather than the smallest ones.
    if (high <= k && (low == high || cnt >= len)) {
      const int take = std::min(cnt, len);
      return psum[l + take] - psum[l];
    }
    const int left_l = map_left[l], left_r = map_left[r + 1];
    const int left_cnt = left_r - left_l;
    // Smaller values live in the left child, so it is served first.
    const int take_left = std::min(cnt, left_cnt);
    const long long left_sum =
        left ? left->sum(left_l, left_r - 1, k, take_left) : 0;
    const int remaining = cnt - take_left;
    const long long right_sum =
        right ? right->sum(l - left_l, r - left_r, k, remaining) : 0;
    return left_sum + right_sum;
  }
};
/// Reads `n` and `m`, then `n` items given as "v c" pairs, and prints the
/// best score found over pairs of endpoints (see below).
///
/// Items are stored as (2 * c, v) and sorted by that pair. For a left
/// endpoint `tm` and a right endpoint `i` with at least m - 2 items strictly
/// between them, the score is
///     v[tm] + v[i] + (sum of the m - 2 largest v strictly between)
///       - (2 * c[i] - 2 * c[tm]).
/// The "m - 2 largest v" sum is answered by a wavelet tree built over the
/// negated v values (largest v == smallest -v).
///
/// The search over (tm, i) uses divide-and-conquer: after finding the best
/// `i` for the middle left endpoint, smaller left endpoints only look at
/// i <= best and larger ones at i >= best.
/// NOTE(review): this relies on the best right endpoint being monotone in
/// the left endpoint — confirm against the problem statement.
///
/// If no valid pair exists the initial sentinel (-1e15) is printed.
int main() {
  std::ios_base::sync_with_stdio(false);
  std::cin.tie(nullptr);
  const int64_t inf = 1e15;
  int n = 0, m = 0;
  std::cin >> n >> m;
  std::vector<std::pair<int64_t, int64_t>> a(n);
  for (auto &[c, v] : a) {
    std::cin >> v >> c;
    c <<= 1;  // store 2 * c so the span cost is a plain difference
  }
  std::sort(a.begin(), a.end());
  int64_t ans = -inf;
  std::vector<int> b(n);
  for (int i = 0; i < n; ++i) {
    // Negated so that "k smallest in b" means "k largest v".
    // Assumes every v fits in int — TODO confirm input bounds.
    b[i] = static_cast<int>(-a[i].second);
  }
  if (b.empty()) {
    // Nothing to choose from; min_element/max_element below would
    // dereference end().
    std::cout << ans << '\n';
    return 0;
  }
  // The constructor reorders b in place; b is not used afterwards.
  WaveletTree wv(b.begin(), b.end(), *std::min_element(b.begin(), b.end()),
                 *std::max_element(b.begin(), b.end()));
  // Handles left endpoints in [tl, tr], whose best right endpoint is known
  // to lie in [l, r].
  auto solve = [&](auto &&self, int tl, int tr, int l, int r) -> void {
    if (tl > tr) {
      return;
    }
    const int tm = (tl + tr) / 2;
    std::pair<int64_t, int> opt = {-inf, -1};
    for (int i = l; i <= r; ++i) {
      const int between = i - tm - 1;  // items strictly inside (tm, i)
      if (between < m - 2) {
        continue;
      }
      // Threshold value of the (m-2)-th smallest -v, then the sum of the
      // m - 2 smallest -v capped at that threshold; negate to get back v.
      const int val = wv.kth(tm + 1, i - 1, m - 2);
      const int64_t inner = -wv.sum(tm + 1, i - 1, val, m - 2);
      opt = std::max(
          opt,
          {a[tm].second + inner + a[i].second - (a[i].first - a[tm].first),
           i});
    }
    ans = std::max(ans, opt.first);
    self(self, tl, tm - 1, l, opt.second);
    self(self, tm + 1, tr, opt.second, r);
  };
  solve(solve, 0, n - m, 0, n - 1);
  std::cout << ans << '\n';
}
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |