이 제출은 이전 버전의 oj.uz에서 채점하였습니다. 현재는 제출 당시와는 다른 서버에서 채점을 하기 때문에, 다시 제출하면 결과가 달라질 수도 있습니다.
#include"holiday.h"
#include <bits/stdc++.h>
using namespace std;
#define _all(T) T.begin(),T.end()
#define ll long long
// Bounded top-k multiset: keeps the largest values pushed so far together with
// their running sum. A min-heap is used so the smallest kept value is the one
// evicted when the capacity shrinks.
struct DS{
    std::priority_queue<long long,std::vector<long long>,std::greater<long long>> pq;
    long long sum = 0; // sum of every value currently stored in pq

    // Discard all stored values and reset the sum.
    void init(){
        sum = 0;
        while(!pq.empty())pq.pop();
    }

    // Insert value k.
    void add(long long k){
        pq.push(k);
        sum += k;
    }

    // Evict the smallest values until at most max(len, 0) remain.
    void shape(int len){
        // Compare in signed 64-bit: a negative len means "keep nothing", and this
        // avoids the signed/unsigned comparison of size() against an int.
        const long long keep = len < 0 ? 0LL : static_cast<long long>(len);
        while(static_cast<long long>(pq.size())>keep){
            sum -= pq.top();
            pq.pop();
        }
    }
};
namespace ZERO{
const ll mxn = 1e5+10;
DS ds;
ll GO(int n,int s,int d,int arr[]){
ds.init();
ll ans = 0;
for(int i = 0;i<=min(d,n-1);i++){
ds.add(arr[i]);
int rest = d-i;
ds.shape(d-i);
ans = max(ans,ds.sum);
}
return ans;
}
}
namespace BRUTE{
const ll mxn = 3030;
int tr1[mxn],tr2[mxn];
ll arr[mxn];
DS ds;
int n,s,d;
ll calc(int e){
ll re = 0;
ds.init();
for(int i = e;i>s;i--)ds.add(arr[i]);
ll tans = 0;
for(int i = s;i>=0;i--){
ds.add(arr[i]);
int rest = d-((e-s)*2+(s-i));
ds.shape(rest);
if(tans<ds.sum){
tans = ds.sum;
tr1[e] = i;
}
}
re = max(re,tans);
ds.init();
tans = 0;
for(int i = e;i>s;i--)ds.add(arr[i]);
for(int i = s;i>=0;i--){
ds.add(arr[i]);
int rest = d-((e-s)+(s-i)*2);
ds.shape(rest);
if(tans<ds.sum){
tans = ds.sum;
tr2[e] = i;
}
}
re = max(re,tans);
return re;
}
void check(int tr[]){
vector<int> v;
for(int i = 0;i<n;i++)if(tr[i] != -1)v.push_back(i);
for(int i = 1;i<v.size();i++){
assert(tr[v[i-1]]<=tr[v[i]]);
assert(v[i-1] == v[i]-1);
}
return;
}
ll GO(int nn,int ss,int dd,int tarr[]){
memset(tr1,-1,sizeof(tr1));
memset(tr2,-1,sizeof(tr2));
n = nn,s = ss,d = dd;
for(int i = 0;i<n;i++)arr[i] = tarr[i];
ll ans = 0;
for(int i = s;i<min(s+d,n);i++)ans = max(ans,calc(i));
check(tr1);
check(tr2);
return ans;
}
}
// Shorthand macros for pair types and pair member access.
// NOTE(review): `fs`/`sc` are very short global macro names and will rewrite any
// later identifier with those spellings; prefer spelling out .first/.second.
#define pii pair<int,int>
#define pll pair<ll,ll>
#define fs first
#define sc second
namespace DC{
// Divide-and-conquer solver. Cities are 1-indexed internally. For each right
// endpoint e of the visited interval [s, e] (s <= st <= e), the best left
// endpoint s is searched by dc(). The quantity "sum of the `rest` largest
// values in arr[s..e]" is answered by a persistent segment tree over
// compressed attraction values.
const int mxn = 1e5+1;

// Sorted, de-duplicated attraction values; a value's index here is its leaf
// position in the segment tree.
std::vector<long long> all;

// One node of the persistent segment tree. Node 0 is the shared empty node.
struct node{
    int pl = 0, pr = 0; // left / right child indices (0 = empty subtree)
    long long sum = 0;  // sum of the values stored below this node
    int cnt = 0;        // number of values stored below this node
};

const int LEN = 2e6+10;

// Persistent (path-copying) segment tree keyed by compressed value index.
struct SEG{
    node seg[LEN];
    int ptr = 0; // index of the most recently allocated node
    SEG(){}

    // Allocate a fresh node initialised as a copy of node k.
    int newnode(int k = 0){
        assert(ptr+1<LEN);
        seg[++ptr] = seg[k];
        return ptr;
    }

    // Return the root of a new version equal to version `now` plus one value v
    // stored at leaf index p of the range [l, r].
    int modify(int now,int l,int r,int p,long long v){
        now = newnode(now);
        if(l == r){
            seg[now].cnt++;
            seg[now].sum += v;
            return now;
        }
        const int m = (l+r)>>1;
        if(p<=m) seg[now].pl = modify(seg[now].pl,l,m,p,v);
        else seg[now].pr = modify(seg[now].pr,m+1,r,p,v);
        seg[now].sum = seg[seg[now].pl].sum+seg[seg[now].pr].sum;
        seg[now].cnt = seg[seg[now].pl].cnt+seg[seg[now].pr].cnt;
        return now;
    }

    // Sum of the `tar` largest values present in version e but not in version s
    // (or of all of them, if fewer than tar exist).
    long long getbig(int s,int e,int l,int r,int tar){
        // Nothing requested. Returning here also keeps the descent from walking
        // right towards leaves that do not exist in `all`.
        if(tar <= 0) return 0;
        if(l == r){
            return all[l]*std::min(tar,seg[e].cnt-seg[s].cnt);
        }
        const int m = (l+r)>>1;
        const int rightCnt = seg[seg[e].pr].cnt-seg[seg[s].pr].cnt;
        if(rightCnt>=tar)return getbig(seg[s].pr,seg[e].pr,m+1,r,tar);
        // Take the whole (larger-valued) right half, then the rest from the left.
        return seg[seg[e].pr].sum-seg[seg[s].pr].sum
            +getbig(seg[s].pl,seg[e].pl,l,m,tar-rightCnt);
    }
};

long long arr[mxn]; // attraction values, 1-indexed
SEG seg;
int rts[mxn];       // rts[i]: root of the version holding arr[1..i]; rts[0] is empty
int n,st,d;         // city count, 1-indexed start city, day budget
long long ans = 0;  // best total found so far

// Value of visiting interval [s, e] when cities right of st cost rightMul days
// of walking each and cities left of st cost leftMul each; -1 if walking alone
// exceeds d.
long long calc(int s,int e,int rightMul,int leftMul){
    assert(s<=st&&st<=e);
    const int rest = d-((e-st)*rightMul+(st-s)*leftMul);
    if(rest<0)return -1;
    return seg.getbig(rts[s-1],rts[e],0,(int)all.size()-1,rest);
}

// Walk right first, then back past the start to the left.
long long calc1(int s,int e){ return calc(s,e,2,1); }
// Walk left first, then back past the start to the right.
long long calc2(int s,int e){ return calc(s,e,1,2); }

// For each right endpoint in [l, r], try left endpoints in [tl, min(e, tr)] and
// fold the best value into ans. The halves are searched on either side of the
// middle's best split, i.e. this assumes the optimal left endpoint is
// non-decreasing in the right endpoint.
void dc(int tl,int tr,int l,int r,int rightMul,int leftMul){
    const int m = (l+r)>>1;
    std::pair<long long,long long> best(-1,-1); // (value, -left endpoint)
    const int hi = std::min(m,tr);
    for(int i = tl;i<=hi;i++){
        best = std::max(best,std::make_pair(calc(i,m,rightMul,leftMul),-(long long)i));
    }
    assert(best.first != -1);
    ans = std::max(ans,best.first);
    const int opt = (int)(-best.second); // ties resolve to the smallest left endpoint
    if(m != l)dc(tl,opt,l,m-1,rightMul,leftMul);
    if(m != r)dc(opt,tr,m+1,r,rightMul,leftMul);
}

void dc1(int tl,int tr,int l,int r){ dc(tl,tr,l,r,2,1); }
void dc2(int tl,int tr,int l,int r){ dc(tl,tr,l,r,1,2); }

// n cities, 0-indexed start city ss, dd days, values tarr[0..n-1].
long long GO(int nn,int ss,int dd,int tarr[]){
    // Reset all global state so repeated calls are independent.
    all.clear();
    seg.ptr = 0;
    ans = 0;
    n = nn, st = ss+1, d = dd; // switch to 1-indexed cities
    assert(n < mxn);
    all.push_back(0);
    for(int i = 1;i<=n;i++){
        arr[i] = tarr[i-1];
        all.push_back(arr[i]);
    }
    std::sort(all.begin(),all.end());
    all.erase(std::unique(all.begin(),all.end()),all.end());
    const int hi = (int)all.size()-1; // last valid leaf index
    rts[0] = 0;
    for(int i = 1;i<=n;i++){
        const int p = (int)(std::lower_bound(all.begin(),all.end(),arr[i])-all.begin());
        rts[i] = seg.modify(rts[i-1],0,hi,p,arr[i]);
    }
    // Furthest right endpoint whose walking cost alone fits in d, per walk order.
    int r1 = st,r2 = st;
    for(int i = st;i<=n;i++){
        if((i-st)*2<=d)r1 = i;
        if(i-st<=d)r2 = i;
    }
    dc1(1,st,st,r1);
    dc2(1,st,st,r2);
    return ans;
}
}
// Entry point: returns DC::GO(n, start, d, attraction), where `start` is the
// 0-indexed start city and `attraction` holds n values.
// The divide-and-conquer solver is used for every input. The ZERO and BRUTE
// solvers are kept only as reference implementations and are not called.
long long int findMaxAttraction(int n, int start, int d, int attraction[]) {
    return DC::GO(n,start,d,attraction);
}
컴파일 시 표준 에러 (stderr) 메시지
In file included from /usr/include/c++/10/cassert:44,
from /usr/include/x86_64-linux-gnu/c++/10/bits/stdc++.h:33,
from holiday.cpp:3:
holiday.cpp: In member function 'void DS::shape(int)':
holiday.cpp:28:31: warning: comparison of integer expressions of different signedness: 'std::priority_queue<long long int, std::vector<long long int>, std::greater<long long int> >::size_type' {aka 'long unsigned int'} and 'int' [-Wsign-compare]
28 | assert(pq.empty()||pq.size()<=len);
| ~~~~~~~~~^~~~~
holiday.cpp: In function 'long long int ZERO::GO(int, int, int, int*)':
holiday.cpp:41:8: warning: unused variable 'rest' [-Wunused-variable]
41 | int rest = d-i;
| ^~~~
holiday.cpp: In function 'void BRUTE::check(int*)':
holiday.cpp:92:18: warning: comparison of integer expressions of different signedness: 'int' and 'std::vector<int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
92 | for(int i = 1;i<v.size();i++){
| ~^~~~~~~~~
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |