#include <bits/stdc++.h>
using namespace std;
// Shorthand macros. NOTE: `int` is textually redefined as `long long` for the
// rest of the file (which is why the entry point is `signed main`), so every
// `int` below — including array elements and loop counters — is 64-bit.
#define int long long
#define ll long long
#define F first
#define S second
#define pb push_back
#define pii pair<int, int>
const int N = 2e5 + 5;  // capacity of the 1-indexed per-item arrays a[] and ver[] (assumes n < N)
const int M = 5e6 + 5;  // size of the node pool st[]; each update() allocates one node per tree level
int n, m;               // n = number of items read; m = number of items a selection contains (see cost())
pii a[N];               // items a[1..n]: .S and .F are read per item; main() sorts them by .F, then .S
vector<int> cmp;        // sorted, deduplicated copy of all a[i].S; tree leaf k stands for cmp[k - 1]
struct Node {
ll sum;
int cnt, lef, rig;
Node() {}
Node(ll _sum, int _cnt, int _lef, int _rig) : sum(_sum), cnt(_cnt), lef(_lef), rig(_rig) {}
} st[M];
int ver[N];       // ver[i]: root (index into st[]) of the version holding a[1..i]; ver[0] == 0 is the empty tree
int numNode = 0;  // nodes allocated so far in st[]; slot 0 stays reserved as the empty node
ll res{0};        // running maximum updated by cal() and printed by main()
void pull(int cur_node) {
st[cur_node].cnt = st[st[cur_node].lef].cnt + st[st[cur_node].rig].cnt;
st[cur_node].sum = st[st[cur_node].lef].sum + st[st[cur_node].rig].sum;
}
// Persistent point insertion: returns the root of a new version equal to the
// version rooted at `id` plus one more copy of the value cmp[i - 1].
//   id   - root of the previous version (0 = empty tree)
//   l, r - leaf range covered by this call (leaves are 1-based ranks into cmp)
//   i    - leaf to insert into
// Only the nodes on the root-to-leaf path are copied; all other subtrees are
// shared with the previous version.
int update(int id, int l, int r, int i) {
    if (l == r) {
        // cmp is deduplicated, so equal values share one leaf. The new leaf must
        // extend the previous version's leaf; restarting at a single copy would
        // lose earlier duplicates and corrupt every count/sum above it.
        st[++numNode] = Node(st[id].sum + cmp[l - 1], st[id].cnt + 1, 0, 0);
        return numNode;
    }
    int mid = (l + r) / 2;
    int cur_id = ++numNode;
    // Start as a copy of the old node, then replace the one child on the path.
    st[cur_id].lef = st[id].lef;
    st[cur_id].rig = st[id].rig;
    if (i <= mid) {
        st[cur_id].lef = update(st[id].lef, l, mid, i);
    }
    else {
        st[cur_id].rig = update(st[id].rig, mid + 1, r, i);
    }
    pull(cur_id);
    return cur_id;
}
// Sum of the k largest values present in version `id` but not in version
// `old_id` (i.e. the multiset difference of the two persistent trees), found by
// descending from the root and preferring the right (larger-value) half.
// Precondition (ensured by the caller): the difference holds at least k values, k >= 1.
ll walk(int old_id, int id, int l, int r, int k) {
    ll taken = 0;
    while (l != r) {
        int mid = (l + r) / 2;
        int rightCount = st[st[id].rig].cnt - st[st[old_id].rig].cnt;
        if (rightCount >= k) {
            // Everything we still need lies in the larger half.
            old_id = st[old_id].rig;
            id = st[id].rig;
            l = mid + 1;
        }
        else {
            // Take the whole larger half, then continue in the smaller one.
            taken += st[st[id].rig].sum - st[st[old_id].rig].sum;
            k -= rightCount;
            old_id = st[old_id].lef;
            id = st[id].lef;
            r = mid;
        }
    }
    return taken + 1ll * min(k, st[id].cnt - st[old_id].cnt) * cmp[l - 1];
}
// Value of the best selection whose outermost items (in sorted order) are a[l]
// and a[r]. The selection consists of:
//   - the S-values of a[l] and a[r];
//   - the m - 2 largest S-values among a[l+1..r-1], taken from the version
//     difference ver[r-1] minus ver[l].
// From that total it subtracts twice the F-spread a[r].F - a[l].F.
// The caller guarantees r - l >= m - 1, so enough inner items exist.
ll cost(int l, int r) {
    ll inner = 0ll;
    if (m > 2) {
        inner = walk(ver[l], ver[r - 1], 1, n, m - 2);
    }
    ll endpoints = a[l].S + a[r].S;
    ll spread = a[r].F - a[l].F;
    return inner + endpoints - 2ll * spread;
}
// Divide and conquer over the left endpoint.
// For the middle left index `mid` of [l, r], try every feasible right index in
// [optl, optr] and keep the best one, then recurse on both halves with the
// narrowed ranges [optl, best] and [best, optr]. This relies on the best right
// endpoint being non-decreasing in the left endpoint.
// Every best value found is folded into the global `res`.
void cal(int l, int r, int optl, int optr) {
    if (l > r) return;
    int mid = (l + r) / 2;
    int best = optr;
    // cost() can be negative (it subtracts twice the F-spread). A 0 sentinel
    // would hide negative optima and, when every candidate is negative, would
    // leave best at optr and wrongly shrink the right half's search range.
    ll bestVal = LLONG_MIN;
    for (int i = max(optl, mid + m - 1); i <= optr; i++) {
        ll cur = cost(mid, i);
        if (cur > bestVal) {
            bestVal = cur;
            best = i;
        }
    }
    res = max(res, bestVal);
    cal(l, mid - 1, optl, best);
    cal(mid + 1, r, best, optr);
}
// Read input, sort the items, build one persistent tree version per prefix
// a[1..i], then search over (left, right) endpoint pairs.
signed main() {
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    cin >> n >> m;
    for (int i = 1; i <= n; i++) {
        cin >> a[i].S >> a[i].F;
        cmp.pb(a[i].S);
    }
    sort(a + 1, a + 1 + n);  // pair order: by .F, then by .S
    sort(cmp.begin(), cmp.end());
    cmp.erase(unique(cmp.begin(), cmp.end()), cmp.end());
    for (int i = 1; i <= n; i++) {
        // 1-based rank of a[i].S within the deduplicated cmp.
        int x = upper_bound(cmp.begin(), cmp.end(), a[i].S) - cmp.begin();
        ver[i] = update(ver[i - 1], 1, n, x);
    }
    // The answer may be negative, so start below every achievable cost.
    // NOTE(review): assumes m <= n so that at least one pair is evaluated.
    res = LLONG_MIN;
    cal(1, n, 1, n);
    cout << res;
}
/*
 * Judge status table pasted from the submission page (not part of the program):
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
*/