Submission #1197150

# | Time | Username | Problem | Language | Result | Execution time | Memory
1197150 | avighna | Cake 3 (JOI19_cake3) | C++20
100 / 100
682 ms | 106388 KiB
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <memory>
#include <utility>
#include <vector>

// Wavelet tree over integer symbols in [low, high].
//
// Each node stores, for its slice of the sequence:
//  - map_left_[i]: how many of the first i slice elements are routed to the
//    left child (symbol <= mid);
//  - psum_[i]: prefix sum of decompressed[symbol] over the first i elements.
// Query positions are always relative to the node's own slice.
class WaveletTree {
 public:
  // Builds the tree over the symbols in [start, end).
  // NOTE: the range is reordered in place by std::stable_partition, so the
  // caller must not rely on its order after construction.
  // `decompressed` is only read during construction (to fill psum_).
  WaveletTree(std::vector<int>::iterator start, std::vector<int>::iterator end,
              int low, int high, const std::vector<int> &decompressed)
      : low_(low), high_(high) {
    if (start >= end) {
      return;
    }
    const int mid = low + (high - low) / 2;
    const auto size = end - start;
    map_left_.reserve(size + 1);
    psum_.reserve(size + 1);
    map_left_.push_back(0);
    psum_.push_back(0);
    for (auto it = start; it != end; ++it) {
      map_left_.push_back(map_left_.back() + (*it <= mid ? 1 : 0));
      psum_.push_back(psum_.back() + decompressed[*it]);
    }
    if (low == high) {
      return;  // leaf: every element carries the same symbol
    }
    const auto pivot =
        std::stable_partition(start, end, [mid](int x) { return x <= mid; });
    if (start != pivot) {
      left_ = std::make_unique<WaveletTree>(start, pivot, low, mid, decompressed);
    }
    if (pivot != end) {
      right_ = std::make_unique<WaveletTree>(pivot, end, mid + 1, high, decompressed);
    }
  }

  // Returns the k-th smallest symbol (1-based k) among slice positions [l, r],
  // or -1 if the range is empty. Expects 1 <= k <= r - l + 1.
  [[nodiscard]] int kth(int l, int r, int k) const {
    if (l > r) return -1;
    if (low_ == high_) return low_;
    const int left_l = map_left_[l];
    const int left_r = map_left_[r + 1];
    const int left_count = left_r - left_l;
    if (k <= left_count) {
      return left_ ? left_->kth(left_l, left_r - 1, k) : -1;
    }
    return right_ ? right_->kth(l - left_l, r - left_r, k - left_count) : -1;
  }

  // Returns {sum of decompressed values, count} over slice positions [l, r]
  // whose symbol is <= k.
  [[nodiscard]] std::pair<long long, int> sum(int l, int r, int k) const {
    if (l > r || k < low_) {
      return {0, 0};
    }
    // If all elements in this range are <= k, return the sum of this range.
    if (high_ <= k) {
      return {psum_[r + 1] - psum_[l], r - l + 1};
    }
    // Split the range between the children; a missing child holds nothing.
    const int left_l = map_left_[l];
    const int left_r = map_left_[r + 1];
    std::pair<long long, int> total{0, 0};
    if (left_) {
      const auto [s, c] = left_->sum(left_l, left_r - 1, k);
      total.first += s;
      total.second += c;
    }
    if (right_) {
      const auto [s, c] = right_->sum(l - left_l, r - left_r, k);
      total.first += s;
      total.second += c;
    }
    return total;
  }

 private:
  int low_;
  int high_;
  std::unique_ptr<WaveletTree> left_;
  std::unique_ptr<WaveletTree> right_;
  std::vector<int> map_left_;
  std::vector<long long> psum_;
};

int main() {
  std::ios_base::sync_with_stdio(false);
  std::cin.tie(nullptr);
  // Sentinel below any achievable answer (assumes the usual JOI bounds:
  // N <= 2e5, V, C <= 1e9 -- TODO confirm against the statement).
  constexpr std::int64_t kInf = 1'000'000'000'000'000;  // 1e15

  int n, m;
  std::cin >> n >> m;

  // cakes[i] = {2 * colour, value}, read as "value colour" per line.
  // Sorted by colour so that picking cakes[lo] and cakes[hi] as the colour
  // extremes costs cakes[hi].first - cakes[lo].first (colour gap counted twice).
  std::vector<std::pair<std::int64_t, std::int64_t>> cakes(n);
  for (auto &[c, v] : cakes) {
    std::cin >> v >> c;
    c <<= 1;
  }
  std::sort(cakes.begin(), cakes.end());

  // Compress values in descending order: symbol 0 is the largest value, so the
  // k-th smallest symbol is the k-th largest value. decompressed[s] holds the
  // NEGATED value of symbol s (assumes values fit in int).
  std::vector<int> decompressed;
  decompressed.reserve(n);
  for (const auto &cake : cakes) {
    decompressed.push_back(static_cast<int>(-cake.second));
  }
  std::sort(decompressed.begin(), decompressed.end());
  decompressed.erase(std::unique(decompressed.begin(), decompressed.end()),
                     decompressed.end());
  std::vector<int> symbols(n);
  for (int i = 0; i < n; ++i) {
    const int key = static_cast<int>(-cakes[i].second);
    symbols[i] = static_cast<int>(
        std::lower_bound(decompressed.begin(), decompressed.end(), key) -
        decompressed.begin());
  }
  const int max_symbol = static_cast<int>(decompressed.size()) - 1;
  // Construction reorders `symbols`; it is not read again afterwards.
  WaveletTree wv(symbols.begin(), symbols.end(), 0, max_symbol, decompressed);

  // Cakes strictly between the two colour extremes (requires m >= 2).
  const int need = m - 2;

  // Sum of the `need` largest values among cakes[l..r]; requires
  // r - l + 1 >= need.
  auto top_sum = [&](int l, int r) -> std::int64_t {
    if (need == 0) return 0;
    const int kth_symbol = wv.kth(l, r, need);
    // Values strictly larger than the k-th largest (symbols < kth_symbol).
    auto [neg_sum, cnt] = wv.sum(l, r, kth_symbol - 1);
    // Fill the remainder with copies of the k-th value. Widen BEFORE
    // multiplying: (need - cnt) * value can exceed the range of int.
    neg_sum += static_cast<std::int64_t>(need - cnt) * decompressed[kth_symbol];
    return -neg_sum;
  };

  std::int64_t ans = -kInf;
  // Divide and conquer over the left extreme `lo` in [tl, tr]; its best right
  // extreme is searched in [l, r]. Relies on the optimal right extreme being
  // non-decreasing in the left extreme.
  auto solve = [&](auto &&self, int tl, int tr, int l, int r) -> void {
    if (tl > tr) {
      return;
    }
    const int lo = (tl + tr) / 2;
    std::pair<std::int64_t, int> best{-kInf, -1};
    for (int hi = l; hi <= r; ++hi) {
      if (hi - lo - 1 < need) {
        continue;  // not enough cakes strictly between lo and hi
      }
      const std::int64_t beauty = cakes[lo].second + top_sum(lo + 1, hi - 1) +
                                  cakes[hi].second -
                                  (cakes[hi].first - cakes[lo].first);
      best = std::max(best, std::pair<std::int64_t, int>{beauty, hi});
    }
    ans = std::max(ans, best.first);
    self(self, tl, lo - 1, l, best.second);
    self(self, lo + 1, tr, best.second, r);
  };
  solve(solve, 0, n - m, 0, n - 1);

  std::cout << ans << '\n';
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...