# | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리 |
---|---|---|---|---|---|---|---|
1081250 | anango | 휴가 (IOI14_holiday) | C++17 | 0 ms | 0 KiB |
이 제출은 이전 버전의 oj.uz에서 채점하였습니다. 현재는 제출 당시와는 다른 서버에서 채점을 하기 때문에, 다시 제출하면 결과가 달라질 수도 있습니다.
#include"holiday.h"
#include <bits/stdc++.h>
// From here on, every `vector<T>` spelled in this file is really `basic_string<T>`.
// NOTE(review): basic_string<long long> needs a char_traits<long long>, which is not
// one of the standard specializations — confirm the toolchain provides a generic one.
// Also, the pb_ds headers below are included after this macro; confirm they never
// spell `vector` themselves.
#define vector basic_string
using namespace std;
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
using namespace __gnu_pbds;
// Order-statistics set (provides find_by_order). It is declared before the `int`
// macro below, so its keys are plain 32-bit int.
typedef tree<int,null_type,less<int>,rb_tree_tag,tree_order_statistics_node_update> ordered_set;
// Every `int` written after this line is `long long`; the grader-facing signature
// at the bottom of the file uses `signed` for its 32-bit parameters.
#define int long long
// Large sentinel, used as the starting value of `best`.
int INF = 1LL<<62;
// Number of slots in the segment-tree array; the leaves cover positions 0..131071.
const int sz = 262144;
// revcoords[i] = position of index i when the current weights are sorted ascending.
vector<int> revcoords;
class SegmentTree {
public:
vector<int> tree;
int n;
ordered_set current_indices;
SegmentTree(int numelem) {
tree=vector<int>(sz,0);
n=numelem;
//cout << "making segtree " << n << endl;
}
void update(int v, int l, int r, int tl, int tr, int addend) {
if (l>tr || r<tl) return;
if (l<=tl && tr<=r) {
assert(l==r);
tree[v]+=addend;
return;
}
int m = (tl+tr)/2;
update(2*v,l,r,tl,m,addend);
update(2*v+1,l,r,m+1,tr,addend);
tree[v] = tree[2*v]+tree[2*v+1];
}
int query(int v, int l, int r, int tl, int tr) {
if (l>tr || r<tl) return 0;
if (l<=tl && tr<=r) {
//cout << "qend " << tl <<" " << tr <<" " << tree[v] << endl;
return tree[v];
}
int m = (tl+tr)/2;
return query(2*v,l,r,tl,m)+query(2*v+1,l,r,m+1,tr);
}
void ins(int index, int weight) {
//weight must be weights[index]
int rc = revcoords[index];
current_indices.insert(rc);
//cout << "updating " << rc <<" " << weight << endl;
update(1,rc,rc,0,131071,weight);
}
void era(int index, int weight) {
//weight must be weights[index]
int rc = revcoords[index];
current_indices.erase(rc);
update(1,rc,rc,0,131071,-weight);
}
int get_sum_maximals(int k) {
//(1 indexed, so k=1 means you just want max etc)
if (current_indices.size()<k) {
return tree[1]; //sum of everything
}
if (k<=0) return 0;
int reqcid = current_indices.size();
reqcid-=k;
int req_index = *current_indices.find_by_order(reqcid);
int ans = query(1,req_index,131071,0,131071);
//cout << "getting " << k <<" " << current_indices.size() <<" " << reqcid <<" " << req_index << " " << ans << endl;
return ans;
}
};
// optimal_k[D]: the k chosen for budget index D by the current solve/solvedouble
// pass (-1 until filled).
vector<int> optimal_k;
// answers[D]: best weight sum found for budget index D in the current pass
// (-1 until filled).
vector<int> answers;
int state = -1; // st currently holds exactly weights[0..state] (state == -1: empty)
// Shared structure used by solve/solvedouble; reassigned with SegmentTree(n)
// at the start of each solve_for_all_d* call.
SegmentTree st(1);
/// Divide-and-conquer pass for the one-way (rightwards) side.
///
/// Fills optimal_k[m] and answers[m] for every m in [l, r], where answers[m]
/// is the maximum over k of the sum of the (m - k - 1) largest of weights[0..k].
/// optimal_k[l-1] and optimal_k[r+1] must already be set; they bound the k
/// searched for every m in between. This assumes the best k is non-decreasing in m.
/// Uses the global `st`, which always holds exactly weights[0..state].
void solve(int l, int r, vector<int> &weights) {
    if (l > r) return;
    const int m = (l + r) / 2;
    const int vleft = optimal_k[l - 1];
    const int vright = optimal_k[r + 1];
    // Slide the window so that st holds exactly weights[0..vleft].
    while (state > vleft) {
        st.era(state, weights[state]);
        state--;
    }
    while (state < vleft) {
        state++;
        st.ins(state, weights[state]);
    }
    int opk = -1;
    int curbest = -1;
    for (int k = vleft; k <= vright; k++) {
        const int visits = m - k - 1;
        // visits only shrinks as k grows, so no later k can contribute either.
        if (visits <= 0) break;
        const int kans = st.get_sum_maximals(visits);
        if (kans > curbest) {  // strict '>' keeps the smallest k on ties
            curbest = kans;
            opk = k;
        }
        assert(state == k);
        // Do not extend past the search bound: vright can be the last index,
        // and weights/revcoords[vright + 1] would then be past the end.
        if (k == vright) break;
        state++;
        st.ins(state, weights[state]);
    }
    optimal_k[m] = opk;
    answers[m] = curbest;
    solve(l, m - 1, weights);
    solve(m + 1, r, weights);
}
/// Best sums for every budget index on the one-way (rightwards) side.
///
/// @param n        number of entries in `weights` (weights[0] is the start city).
/// @param weights  city weights in walking order.
/// @return answers of size 2n+1, where answers[D] is the maximum over k of the
///         sum of the (D - k - 1) largest of weights[0..k]. answers[0] and
///         answers[1] are 0, and answers[2n] is the sum of all weights.
///         Returns {0} when n == 0.
vector<int> solve_for_all_d(int n, vector<int> weights) {
    if (n == 0) {
        // Nothing to visit. This also avoids writing optimal_k[2] into a 1-slot buffer.
        return {0};
    }
    st = SegmentTree(n);
    // revcoords[i] = position of index i when sorted by ascending weight.
    revcoords = vector<int>(n, -1);
    vector<int> coords(n, 0);  // basic_string has no (count)-only constructor
    iota(coords.begin(), coords.end(), (int)0);
    sort(coords.begin(), coords.end(), [&](const int i1, const int i2) {
        return weights[i1] < weights[i2];
    });
    for (int i = 0; i < n; i++) {
        revcoords[coords[i]] = i;
    }
    // Known boundary values for the divide and conquer:
    // D = 2 takes only weights[0]; D = 2n reaches and takes every weight.
    optimal_k = vector<int>(2 * n + 1, -1);
    optimal_k[2] = 0;
    optimal_k[2 * n] = n - 1;
    answers = vector<int>(2 * n + 1, -1);
    answers[0] = answers[1] = 0;
    answers[2] = weights[0];
    answers[2 * n] = accumulate(weights.begin(), weights.end(), (int)0);
    state = -1;
    solve(3, 2 * n - 1, weights);
    return answers;
}
/// Divide-and-conquer pass for the out-and-back (leftwards) side.
///
/// Fills optimal_k[m] and answers[m] for every m in [l, r], where answers[m]
/// is the maximum over k of the sum of the (m - 2k - 2) largest of weights[0..k].
/// optimal_k[l-1] and optimal_k[r+1] must already be set; they bound the k
/// searched for every m in between. This assumes the best k is non-decreasing in m.
/// Uses the global `st`, which always holds exactly weights[0..state].
void solvedouble(int l, int r, vector<int> &weights) {
    if (l > r) return;
    const int m = (l + r) / 2;
    const int vleft = optimal_k[l - 1];
    const int vright = optimal_k[r + 1];
    // Slide the window so that st holds exactly weights[0..vleft].
    while (state > vleft) {
        st.era(state, weights[state]);
        state--;
    }
    while (state < vleft) {
        state++;
        st.ins(state, weights[state]);
    }
    int opk = -1;
    int curbest = -1;
    for (int k = vleft; k <= vright; k++) {
        const int visits = m - 2 * k - 2;
        // visits only shrinks as k grows, so no later k can contribute either.
        if (visits <= 0) break;
        const int kans = st.get_sum_maximals(visits);
        if (kans > curbest) {  // strict '>' keeps the smallest k on ties
            curbest = kans;
            opk = k;
        }
        assert(state == k);
        // Do not extend past the search bound: vright can be the last index,
        // and weights/revcoords[vright + 1] would then be past the end.
        if (k == vright) break;
        state++;
        st.ins(state, weights[state]);
    }
    optimal_k[m] = opk;
    answers[m] = curbest;
    solvedouble(l, m - 1, weights);
    solvedouble(m + 1, r, weights);
}
/// Best sums for every budget index on the out-and-back (leftwards) side.
///
/// @param n        number of entries in `weights` (weights[0] is the city next
///                 to the start).
/// @param weights  city weights in walking order.
/// @return answers of size 3n+1, where answers[L] is the maximum over k of the
///         sum of the (L - 2k - 2) largest of weights[0..k]. answers[0..2] are 0,
///         and answers[3n] is the sum of all weights. Returns {0} for n == 0
///         and {0,0,0,weights[0]} for n == 1.
vector<int> solve_for_all_d_double(int n, vector<int> weights) {
    if (n == 0) {
        return {0};
    }
    if (n == 1) {
        return {0,0,0,weights[0]};
    }
    st = SegmentTree(n);
    // revcoords[i] = position of index i when sorted by ascending weight.
    revcoords = vector<int>(n, -1);
    vector<int> coords(n, 0);  // basic_string has no (count)-only constructor
    iota(coords.begin(), coords.end(), (int)0);
    sort(coords.begin(), coords.end(), [&](const int i1, const int i2) {
        return weights[i1] < weights[i2];
    });
    for (int i = 0; i < n; i++) {
        revcoords[coords[i]] = i;
    }
    // Known boundary values for the divide and conquer:
    // L = 3 takes only weights[0]; L = 3n reaches and takes every weight.
    optimal_k = vector<int>(3 * n + 1, -1);
    optimal_k[3] = 0;
    optimal_k[3 * n] = n - 1;
    answers = vector<int>(3 * n + 1, -1);
    answers[0] = answers[1] = answers[2] = 0;
    answers[3] = weights[0];
    answers[3 * n] = accumulate(weights.begin(), weights.end(), (int)0);
    state = -1;
    solvedouble(4, 3 * n - 1, weights);
    return answers;
}
vector<int> attractions;
int best = -INF;
long long findMaxAttraction(signed n, signed start, signed d, signed attraction[]) {
for (int i=0; i<n; i++) {
attractions.push_back(attraction[i]);
}
vector<int> right;
for (int i=start; i<n; i++) {
right.push_back(attractions[i]);
}
vector<int> left;
for (int i=start-1; i>=0; i--) {
left.push_back(attractions[i]);
}
vector<int> right_ans = solve_for_all_d(right.size(), right);
vector<int> left_ans = solve_for_all_d_double(left.size(), left);
while (right_ans.size()<=d+8) {
right_ans.push_back(right_ans.back());
}
while (left_ans.size()<=d+8) {
left_ans.push_back(left_ans.back());
}
for (int i=0; i<right_ans.size(); i++) {
//cout << "right " << i <<" " << right_ans[i] << endl;
}
for (int i=0; i<left_ans.size(); i++) {
//cout << "left " << i <<" " << left_ans[i] << endl;
}
//left then right
for (int days_left=0; days_left<=d+1; days_left++) {
int days_right = d-days_left+1;
if (days_right<0) continue;
int total_ans = left_ans[days_left]+right_ans[days_right];
best=max(best,total_ans);
}
start=n-start-1;
right.clear();
left.clear();
reverse(attractions.begin(), attractions.end());
for (int i=start; i<n; i++) {
right.push_back(attractions[i]);
}
for (int i=start-1; i>=0; i--) {
left.push_back(attractions[i]);
}
right_ans = solve_for_all_d(right.size(), right);
left_ans = solve_for_all_d_double(left.size(), left);
while (right_ans.size()<=d+8) {
right_ans.push_back(right_ans.back());
}
while (left_ans.size()<=d+8) {
left_ans.push_back(left_ans.back());
}
for (int i=0; i<right_ans.size(); i++) {
//cout << "right " << i <<" " << right_ans[i] << endl;
}
for (int i=0; i<left_ans.size(); i++) {
//cout << "left " << i <<" " << left_ans[i] << endl;
}
//left then right
for (int days_left=0; days_left<=d+1; days_left++) {
int days_right = d-days_left+1;
if (days_right<0) continue;
int total_ans = left_ans[days_left]+right_ans[days_right];
best=max(best,total_ans);
}
return best;
}