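// 이 제출은 이전 버전의 oj.uz에서 채점하였습니다. 현재는 제출 당시와는 다른 서버에서 채점을 하기 때문에, 다시 제출하면 결과가 달라질 수도 있습니다.
// (oj.uz notice) This submission was judged on a previous version of oj.uz. Judging now runs on a different server than at submission time, so resubmitting may give a different result.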
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
using ii = pair<int, int>;
using vi = vector<int>;
#define all(v) begin(v), end(v)
#define sz(v) (int)(v).size()
#define fi first
#define se second
#ifdef LOCAL
#include "debug.h"
#else
#define debug(...)
#endif
// Maximum number of pieces (+ slack for 1-based indexing).
const int N = 2e5 + 5;
// Every update() clones one node per level of the segment tree over [1, n]:
// ceil(log2 n) + 1 levels, i.e. at most 19 for n < 2^18. All n insertions
// therefore need about n * 19 + 1 nodes; 20 per piece leaves headroom.
// (A pool of only N nodes is overrun by ptr for large n.)
const int NODES_PER_INSERT = 20;
// Persistent segment-tree node. All fields default to zero, so t[0] acts as
// the shared empty node whose children point back at itself.
struct node {
  ll sum = 0;        // sum of the values stored in this subtree
  int cnt = 0;       // number of values stored in this subtree
  int l = 0, r = 0;  // child indices into t[]; 0 is the empty node -- careful
};
// Node pool: t[0] is the empty node, fresh nodes are handed out via ++ptr.
node t[N * NODES_PER_INSERT];
int ptr = 0;
// n pieces, m of them chosen; root[i] is the tree version holding pieces 1..i.
int n, m, root[N];
// Per-piece V and C values (1-indexed).
ll V[N], C[N];
// Persistent point insert: adds `val` at leaf `pos` of the tree covering [b, e].
// Each node on the root-to-leaf path is cloned into a fresh slot of t[]
// (older versions stay untouched), and the clone immediately absorbs the new
// element in sum/cnt, so no pull step is needed afterwards.
// On return `u` refers to the new root of this version.
void update(int& u, int b, int e, int pos, ll val) {
  int* slot = &u;  // link (root or child field) that must point at the clone
  for (;;) {
    const int fresh = ++ptr;
    t[fresh] = t[*slot];
    *slot = fresh;
    t[fresh].sum += val;
    t[fresh].cnt += 1;
    if (b == e) return;
    const int mid = (b + e) >> 1;
    if (pos > mid) {
      slot = &t[fresh].r;
      b = mid + 1;
    } else {
      slot = &t[fresh].l;
      e = mid;
    }
  }
}
// Sum of the k largest values among the elements present in version v but not
// in version u (the difference [v] \ [u]), within the rank range [b, e].
// main() assigns ranks by descending V, so the left subtree always holds the
// larger values and is consumed first.
// Precondition (not checked): t[v].cnt - t[u].cnt >= k.
ll largest_ksum(int u, int v, int b, int e, int k) { // [v] \ [u]
  ll taken = 0;
  // Descend until the current subtree difference holds exactly k elements;
  // at that point every one of them is taken.
  while (t[v].cnt - t[u].cnt != k) {
    const int mid = (b + e) >> 1;
    const int lu = t[u].l;
    const int lv = t[v].l;
    const int leftCnt = t[lv].cnt - t[lu].cnt;
    if (leftCnt < k) {
      // Fewer than k in the left half: take all of it, find the rest on the right.
      taken += t[lv].sum - t[lu].sum;
      k -= leftCnt;
      u = t[u].r;
      v = t[v].r;
      b = mid + 1;
    } else {
      // The left half alone has at least k elements: the answer lies there.
      u = lu;
      v = lv;
      e = mid;
    }
  }
  return taken + (t[v].sum - t[u].sum);
}
// Value of the best selection whose endpoints (in C-sorted order) are pieces
// l and r. Both endpoints are taken, the other m-2 are the largest V strictly
// between them (versions root[l]..root[r-1]), and 2*(C[r]-C[l]) is subtracted.
// Requires r - l - 1 >= m - 2.
ll cost(int l, int r) {
  const ll between = largest_ksum(root[l], root[r - 1], 1, n, m - 2);
  const ll endpoints = V[l] + V[r];
  const ll penalty = 2 * (C[r] - C[l]);
  return between + endpoints - penalty;
}
ll ans = LLONG_MIN;
void solve(int l, int r, int b, int e) { // solve for [l, r] in search space [b, e]
if(l > r) return;
pair<ll,int> opt(LLONG_MIN, -1);
int mid = (l + r) >> 1;
for(int i = b; i <= min(e, mid-m+1); ++i) {
opt = max(opt, {cost(i, mid), i});
}
assert(~opt.se);
ans = max(ans, opt.fi);
solve(l, mid-1, b, opt.se);
solve(mid+1, r, opt.se, e);
}
int main(int argc, char const *argv[])
{
cin.sync_with_stdio(0); cin.tie(0);
cin.exceptions(cin.failbit);
#ifdef LOCAL
freopen("in", "r", stdin);
#endif
cin >> n >> m;
for(int i = 1; i <= n; ++i) {
cin >> V[i] >> C[i];
}
vi p(n); iota(all(p), 1);
sort(all(p), [](int i,int j){ return C[i] < C[j]; });
{
vi VV(n), CC(n);
for(int i = 0; i < n; ++i) {
VV[i] = (int)V[p[i]];
CC[i] = (int)C[p[i]];
}
for(int i = 1; i <= n; ++i) {
V[i] = VV[i-1];
C[i] = CC[i-1];
}
}
debug(vi(C+1, C+1+n));
debug(vi(V+1, V+1+n));
vi v_ord(n); iota(all(v_ord), 1);
sort(all(v_ord), [](int i, int j){ return V[i] > V[j]; });
vi v_rank(n+1);
for(int i = 0; i < n; ++i) {
v_rank[v_ord[i]] = i+1;
}
debug(v_ord);
debug(vi(v_rank.begin() + 1, v_rank.end()));
root[0] = 0;
for(int i = 1; i <= n; ++i) {
root[i] = root[i-1];
update(root[i], 1, n, v_rank[i], V[i]);
}
solve(m, n, 1, n);
cout << ans << endl;
return 0;
}
/*
* use std::array instead of std::vector if you can
* overflow?
* array bounds
*/
/*
oj.uz result table (scraped; the results had not finished loading):
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
*/