이 제출은 이전 버전의 oj.uz에서 채점하였습니다. 현재는 제출 당시와는 다른 서버에서 채점을 하기 때문에, 다시 제출하면 결과가 달라질 수도 있습니다.
#include <bits/stdc++.h>
using namespace std;
// n and m are the two integers read first in main(); n is the number of (v, c) pairs that follow.
int n,m;
// Input pairs stored as (v, c); main() sorts them by c (the second field).
vector<pair<int,int> > pieces;
// Node-pool sizing factor: every per-node array below has COEFF*MXN slots.
const int COEFF = 32;
// Capacity bound used to size the per-position array qs and the node pool.
const int MXN = 200005;
// segT[x]: number of leaves marked by update() inside node x's subtree.
int segT[COEFF*MXN];
// qs[i]: prefix sum of pieces[0..i].first in sorted order (filled in main()).
long long qs[MXN];
// qsegT[x]: sum of the amounts stored at the marked leaves inside node x's subtree.
long long qsegT[COEFF*MXN];
// lb[x], rb[x]: inclusive index range covered by node x (set in build(), copied by newnode()).
int lb[COEFF*MXN];
int rb[COEFF*MXN];
// llink[x], rlink[x]: node ids of x's left and right children.
int llink[COEFF*MXN];
int rlink[COEFF*MXN];
// Id of the most recently allocated node; real ids start at 1 because newnode() pre-increments.
int nodesz;
// Hands out the next unused node id from the global pool.
// When `o` names an existing node, every per-node field of `o` is cloned into
// the new node; with the default (-1) the new node keeps whatever the pool
// slot already holds (zero for untouched global storage).
inline int newnode(int o = -1){
    const int fresh = ++nodesz;
    if(o == -1){
        return fresh;
    }
    // Clone structure first, then the aggregated payload.
    llink[fresh] = llink[o];
    rlink[fresh] = rlink[o];
    lb[fresh] = lb[o];
    rb[fresh] = rb[o];
    segT[fresh] = segT[o];
    qsegT[fresh] = qsegT[o];
    return fresh;
}
int build(int l, int r){
int c = newnode();
lb[c] = l;
rb[c] = r;
if(l == r) return c;
int k = (l + r)/2;
llink[c] = build(l, k);
rlink[c] = build(k+1, r);
return c;
}
// Produces a new tree version from version `o` in which leaf `idx` is marked
// (count 1) and carries `amt`. Only the nodes on the root-to-leaf path are
// copied; everything else is shared with `o`. Returns the new root id.
int update(int o, int idx, int amt){
    const int cur = newnode(o);
    const bool is_leaf = (lb[cur] == rb[cur]);
    if(is_leaf){
        segT[cur] = 1;
        qsegT[cur] = amt;
        return cur;
    }
    const int half = (lb[cur] + rb[cur])/2;
    // `cur` was cloned from `o`, so the chosen slot still holds o's child id.
    int& child = (idx <= half) ? llink[cur] : rlink[cur];
    child = update(child, idx, amt);
    // Re-aggregate count and sum from the (possibly replaced) children.
    const int left = llink[cur];
    const int right = rlink[cur];
    segT[cur] = segT[left] + segT[right];
    qsegT[cur] = qsegT[left] + qsegT[right];
    return cur;
}
// Walks two tree versions in lockstep (`c` newer, `o` older) and returns the
// sum carried by the `ord` lowest-indexed leaves that are marked in `c` but
// not in `o`. Returns 0 when ord == 0.
long long sumdescent(int c, int o, int ord){
    long long gathered = 0;   // sum of left subtrees already taken in full
    while(ord != 0){
        const long long here = qsegT[c] - qsegT[o];
        // Take the whole node if it holds exactly `ord` items, or if it is a leaf.
        if(ord == segT[c] - segT[o] || lb[c] == rb[c]){
            return gathered + here;
        }
        const int left_new = llink[c];
        const int left_old = llink[o];
        const int in_left = segT[left_new] - segT[left_old];
        if(in_left >= ord){
            // Everything still needed lies in the left child.
            c = left_new;
            o = left_old;
        }else{
            // Consume the left child entirely, continue on the right.
            gathered += qsegT[left_new] - qsegT[left_old];
            ord -= in_left;
            c = rlink[c];
            o = rlink[o];
        }
    }
    return gathered;
}
// rootlist[i]: root id of the tree version main() obtains after inserting pieces[0..i].
vector<int> rootlist;
// Returns the sum of the (m-2) largest `.first` values among pieces[l..r]
// (positions in the sorted order).
//
// It is computed as (sum of the whole range) - (sum of the smallest
// `r-l+1-(m-2)` values); the latter comes from the tree versions
// rootlist[l-1] and rootlist[r], whose leaves are ordered by ascending value.
//
// Preconditions (not checked): l >= 1, because rootlist[l-1] and qs[l-1] are
// read, and the range must hold at least m-2 elements.
long long query(int l, int r){
    // How many of the smallest values must be dropped so that m-2 remain.
    const int items = r-l+1 - (m-2);
    const long long qr = sumdescent(rootlist[r], rootlist[l-1], items);
    const long long rsq = qs[r] - qs[l-1];
    return rsq - qr;
}
// Divide-and-conquer search for the best (left, right) pair of positions.
//
// For every right position in [l, r] the left position is searched in
// [lb, rb] (additionally capped at right-m+1). The score of a pair (i, mid) is
//   query(i+1, mid-1) + v[i] + v[mid] + 2*c[i] - 2*c[mid].
// The best left position found for the middle right position bounds the
// candidate range of both recursive halves. Returns -1e18 if no pair exists.
//
// Note: the parameters `lb` and `rb` shadow the global arrays of the same
// name inside this function; here they are plain index bounds.
long long solve(int l, int r, int lb, int rb){
    const int mid = (l + r)/2;
    long long best = -1e18;
    int best_at = -1;

    // Contribution of the right endpoint does not depend on the candidate.
    const long long right_part = 1ll * pieces[mid].first - 2ll * pieces[mid].second;
    const int last_candidate = min(mid-m+1, rb);
    for(int cand = lb; cand <= last_candidate; ++cand){
        const long long left_part = 1ll * pieces[cand].first + 2ll * pieces[cand].second;
        const long long score = query(cand+1, mid-1) + left_part + right_part;
        if(score > best){
            best = score;
            best_at = cand;
        }
    }

    // With no candidate for `mid`, the halves keep the full candidate range.
    const int left_half_hi = (best_at == -1) ? rb : best_at;
    const int right_half_lo = (best_at == -1) ? lb : best_at;
    if(l < mid){
        best = max(best, solve(l, mid-1, lb, left_half_hi));
    }
    if(mid < r){
        best = max(best, solve(mid+1, r, right_half_lo, rb));
    }
    return best;
}
// mp[i]: rank of pieces[i] when the pairs (first, position) are sorted ascending; main() uses it as the leaf index.
int mp[MXN];
int main(){
scanf("%d%d",&n,&m);
int v,c;
vector<pair<int,int> > rankofv;
for(int i = 0; i < n; i++){
scanf("%d%d",&v,&c);
pieces.emplace_back(v,c);
}
stable_sort(pieces.begin(), pieces.end(), [](pair<int,int> x, pair<int,int> y){
return x.second < y.second;
});
for(int i = 0; i < n; i++){
rankofv.emplace_back(pieces[i].first,i);
}
sort(rankofv.begin(), rankofv.end());
for(int i = 0; i < rankofv.size(); i++){
mp[rankofv[i].second] = i;
}
qs[0] = pieces[0].first;
for(int i = 1; i < n; i++){
qs[i] = qs[i-1] + 1ll * pieces[i].first;
}
int initroot = build(0,n-1);
int lastroot = initroot;
for(int i = 0; i < n; i++){
int idx = mp[i];
//printf("DBG %d %d %lld\n",idx,pieces[i].first,qsegT[lastroot]);
rootlist.push_back(update(lastroot, idx, pieces[i].first));
lastroot = rootlist.back();
}
long long ans = solve(m-1,n-1,0,n-1);
printf("%lld\n",ans);
return 0;
}
컴파일 시 표준 에러 (stderr) 메시지
cake3.cpp: In function 'long long int query(int, int)':
cake3.cpp:64:9: warning: unused variable 'croot' [-Wunused-variable]
int croot = rootlist[r];
^~~~~
cake3.cpp: In function 'int main()':
cake3.cpp:101:22: warning: comparison between signed and unsigned integer expressions [-Wsign-compare]
for(int i = 0; i < rankofv.size(); i++){
~~^~~~~~~~~~~~~~~~
cake3.cpp:87:10: warning: ignoring return value of 'int scanf(const char*, ...)', declared with attribute warn_unused_result [-Wunused-result]
scanf("%d%d",&n,&m);
~~~~~^~~~~~~~~~~~~~
cake3.cpp:91:14: warning: ignoring return value of 'int scanf(const char*, ...)', declared with attribute warn_unused_result [-Wunused-result]
scanf("%d%d",&v,&c);
~~~~~^~~~~~~~~~~~~~
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |