// Note: this submission was judged on an older version of oj.uz. Judging now runs on a different server than at submission time, so resubmitting may produce a different result.
#include<bits/stdc++.h>
#define mp make_pair
#define pb push_back
#define ii pair<int,int>
#define all(x) (x).begin(),(x).end()
#define INF 100000000000000000
#define modulo 1000000007
#define mod 998244353
#define int long long int
using namespace std;
// One vertex of the segment tree. Vertices refer to each other by integer
// index rather than by pointer.
struct Node{
    // Indices of the two children.
    int left, right;
    // Bounds recorded for this vertex by set().
    int L, R;
    // Aggregated value stored at this vertex.
    int sum;
    // Aggregated element count stored at this vertex.
    int cnt;

    // Empty user-provided constructor: no field is initialized here, so a
    // freshly constructed Node holds indeterminate values until set() runs.
    Node(){}

    // Overwrites all six fields at once:
    //   a, b -> left, right
    //   c, d -> L, R
    //   e    -> sum
    //   f    -> cnt
    void set(int a, int b, int c, int d, int e, int f){
        left = a;   right = b;
        L = c;      R = d;
        sum = e;    cnt = f;
    }
};
// Fixed-size pool of tree nodes; new_node() hands out indices into it.
vector<Node> seg(7000000);
// root[i]: index in seg of the tree version produced by the i-th pers() call in main.
vector<int> root(200005);
// Input pairs stored as (y, x) and sorted ascending in main.
vector<ii> arr;
// Pairs (arr[i].second, i), sorted in descending order in main.
vector<ii> util;
// ptr[i]: position in the sorted util of the entry that came from arr[i].
vector<int> ptr(200005);
// Next unused index in seg.
int ind = 0;
// Number of leaf positions in the tree (a power of two computed in main).
int S;
// Returns the current value of the global counter `ind` and then advances it
// by one, so every call yields a distinct index. No bounds check is made.
int new_node(){
    const int fresh = ind;
    ++ind;
    return fresh;
}
int getsum(int i, int j){
return seg[j].sum - seg[i].sum;
}
int getcnt(int i, int j){
return seg[j].cnt - seg[i].cnt;
}
// Recursively fills in an all-zero tree over positions [l, r], rooted at the
// already-allocated node `curr`.
//   curr : index in seg of the node to initialize
//   l, r : inclusive range covered by that node
void build(int curr, int l, int r){
    if(l != r){
        // Both children are allocated inside the argument list. C++ does not
        // specify which of the two calls runs first, so which fresh index
        // becomes `left` and which becomes `right` is up to the compiler;
        // either assignment yields a valid tree.
        seg[curr].set(new_node(), new_node(), l, r, 0, 0);
        const int mid = (l + r) / 2;
        build(seg[curr].left, l, mid);
        build(seg[curr].right, mid + 1, r);
    }
    else{
        // Leaf: no children, recorded as -1.
        seg[curr].set(-1, -1, l, r, 0, 0);
    }
}
// Path-copying point update. Starting from the existing node `j` covering
// [l, r], creates new nodes along the path to position x and returns the index
// of the new node for [l, r]. The subtree that does not contain x is reused
// from `j` unchanged.
//   j    : index of the node being copied
//   l, r : inclusive range covered by j
//   x    : position whose leaf is rewritten
// The new leaf gets sum = util[l].first (with l == r at the leaf) and cnt = 1.
int pers(int j, int l, int r, int x){
    // The copy is allocated before recursing, exactly one node per level.
    const int curr = new_node();
    if(l == r){
        seg[curr].set(-1, -1, l, r, util[l].first, 1);
        return curr;
    }

    const int mid = (l + r) / 2;
    // Start from the old children and replace only the side containing x.
    int lc = seg[j].left;
    int rc = seg[j].right;
    if(x <= mid){
        lc = pers(lc, l, mid, x);
    }
    else{
        rc = pers(rc, mid + 1, r, x);
    }
    seg[curr].set(lc, rc, l, r, seg[lc].sum + seg[rc].sum, seg[lc].cnt + seg[rc].cnt);
    return curr;
}
// Walks two trees in lockstep over [l, r] and works on their difference
// (fields of j2 minus fields of j1). Returns the summed value of up to x
// entries of that difference, consuming the left child before the right one.
//   j1, j2 : node indices for the same range in the two trees
//            (presumably j1 is an earlier version of j2 -- confirm at call sites)
//   l, r   : inclusive range covered by both nodes
//   x      : how many entries to take
int query(int j1, int l, int r, int x, int j2){
    const int available = getcnt(j1, j2);
    if(available == 0 || x == 0){
        return 0;
    }
    // A leaf, or a range from which every entry is wanted, contributes its whole sum.
    if(l == r || x == available){
        return getsum(j1, j2);
    }

    const int mid = (l + r) / 2;
    const int leftOld = seg[j1].left;
    const int leftNew = seg[j2].left;
    const int fromLeft = query(leftOld, l, mid, x, leftNew);
    // Whatever the left half could not supply is requested from the right half.
    const int remaining = max(x - getcnt(leftOld, leftNew), 0ll);
    return fromLeft + query(seg[j1].right, mid + 1, r, remaining, seg[j2].right);
}
void print(int j, int l, int r){
cout << "("<<l<<","<<r<<") = "<< seg[j].sum << " " << seg[j].cnt << "\n";
if(l == r){
return;
}
print(seg[j].left, l, (l + r) / 2);
print(seg[j].right, (l + r) / 2 + 1, r);
}
// Divide-and-conquer search for the best score over right endpoints in [l, r].
// For the middle endpoint `mid`, every left endpoint k in [optl, min(optr, mid)]
// with at least m elements in [k, mid] is scored as
//     query(version before k, version before mid, m - 1)
//       - 2 * (arr[mid].first - arr[k].first) + arr[mid].second
// and the best k then bounds the candidate range for both halves.
//   l, r       : inclusive range of right endpoints still to process
//   optl, optr : inclusive range of left endpoints to try
//   m          : required number of elements in a window
// Returns the best score found, or -INF if no window is large enough.
int compute(int l, int r, int optl, int optr, int m){
    if(l > r) return -INF;
    const int mid = (l + r) / 2;

    // Version holding positions 0..mid-1. For mid == 0 there is no such
    // version, so use node 0, the empty tree built first in main. Indexing
    // root[mid - 1] unconditionally would read root[-1] whenever m <= 1.
    const int upperRoot = (mid > 0 ? root[mid - 1] : 0);
    const int lastK = min(optr, mid);

    int best = -INF;
    int opt = optl;
    for(int k = optl; k <= lastK; k++){
        // Window length only shrinks as k grows, so the first window that is
        // too short ends the scan.
        if(mid - k + 1 < m) break;
        const int lowerRoot = (k > 0 ? root[k - 1] : 0);
        const int cand = query(lowerRoot, 0, S - 1, m - 1, upperRoot)
                         - 2 * (arr[mid].first - arr[k].first) + arr[mid].second;
        if(cand > best){
            best = cand;
            opt = k;
        }
    }

    const int leftBest = compute(l, mid - 1, optl, opt, m);
    const int rightBest = compute(mid + 1, r, opt, optr, m);
    return max(best, max(leftBest, rightBest));
}
// Entry point. Reads n and m, then n pairs (x, y); builds one tree version per
// element of the y-sorted input; prints the value returned by compute().
int32_t main(){
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);

    int n, m;
    cin >> n >> m;
    for(int i = 0; i < n; i++){
        int x, y;
        cin >> x >> y;
        // Stored as (y, x) so that the default pair ordering sorts by y first.
        arr.pb({y, x});
    }
    sort(all(arr));

    // util ranks the x values in descending order; ptr maps each index of arr
    // to its position in that ranking, which is its leaf position in the tree.
    for(int i = 0; i < n; i++){
        util.pb({arr[i].second, i});
    }
    sort(all(util), greater<ii>());
    for(int i = 0; i < n; i++) ptr[util[i].second] = i;

    // Smallest power of two that is >= n, computed with integers only. This
    // avoids depending on floating-point log2/ceil rounding and stays defined
    // for n == 0, where log2 would return -infinity.
    S = 1;
    while(S < n) S *= 2;

    // The first node allocated (index 0) becomes the root of the empty tree.
    build(new_node(), 0, S - 1);
    for(int i = 0; i < n; i++){
        root[i] = pers((i > 0 ? root[i - 1] : 0), 0, S - 1, ptr[i]);
    }

    cout << compute(0, n - 1, 0, n - 1, m);
}
// # | Verdict | Execution time | Memory | Grader output |
// ---|---|---|---|---|
// Fetching results... |
// # | Verdict | Execution time | Memory | Grader output |
// ---|---|---|---|---|
// Fetching results... |
// # | Verdict | Execution time | Memory | Grader output |
// ---|---|---|---|---|
// Fetching results... |