#include "holiday.h"
#include <bits/stdc++.h>
using namespace std;
#define ll long long
#define pi pair<int, int>
#define pl pair<ll, ll>
#define vi vector<int>
#define vl vector<ll>
#define fi first
#define se second
#define pb push_back
#define all(x) (x).begin(),(x).end()
// Segment tree whose leaves are the values passed to init(), in that order
// (findMaxAttraction passes the attractions sorted in descending order).
// Every leaf can be switched on (state 1) or off (state 0). Each node keeps
// the number of switched-on leaves in its subtree (`act`) and their total
// (`val`), so get(x) returns the total of the first x switched-on leaves --
// i.e. the x largest active values when the leaves are sorted descending.
struct segTree {
    // Subtree summary. `act` = active leaves below, `val` = their sum.
    // `tval` is the leaf's own value, kept so a leaf can be switched back on;
    // it is only read at leaves (in chstate).
    struct node {
        int act = 1;
        long long val = 0;
        long long tval = 0;
        node(int _act = 1, long long _val = 0) : act(_act), val(_val), tval(_val) {}
    };

    // Merge two sibling summaries: counts and sums simply add.
    node unite(node a, node b) {
        const int count = a.act + b.act;
        const long long sum = a.val + b.val;
        return node(count, sum);
    }

    std::vector<node> nodes;
    int sze;

    // Builds the tree over val[0..n-1]; every leaf starts switched on.
    void init(int n, std::vector<int>& val) {
        sze = n;
        nodes.resize(4 * sze);
        build(1, 0, n - 1, val);
    }

    // Recursively fills node v, which covers leaves [tl, tr].
    void build(int v, int tl, int tr, std::vector<int>& val) {
        if (tl != tr) {
            const int mid = tl + (tr - tl) / 2;
            const int lc = 2 * v;
            const int rc = 2 * v + 1;
            build(lc, tl, mid, val);
            build(rc, mid + 1, tr, val);
            nodes[v] = unite(nodes[lc], nodes[rc]);
        } else {
            nodes[v] = node(1, val[tl]);
        }
    }

    // Sets leaf `ind` to `state` (0 = off, 1 = on) below node v, which covers
    // [tl, tr], and refreshes the summaries on the path back up.
    void chstate(int v, int tl, int tr, int ind, int state) {
        if (ind < tl || ind > tr) {
            return;
        }
        if (tl == tr) {
            node& leaf = nodes[v];
            leaf.act = state;
            leaf.val = leaf.tval * state;
            return;
        }
        const int mid = tl + (tr - tl) / 2;
        if (ind <= mid) {
            chstate(2 * v, tl, mid, ind, state);
        } else {
            chstate(2 * v + 1, mid + 1, tr, ind, state);
        }
        nodes[v] = unite(nodes[2 * v], nodes[2 * v + 1]);
    }

    void chstate(int ind, int state) {
        chstate(1, 0, sze - 1, ind, state);
    }

    // Summary of the first x active leaves under node v (x >= 1). If the
    // subtree holds at most x active leaves, the whole subtree is taken.
    node get(int v, int tl, int tr, int x) {
        const node& here = nodes[v];
        if (here.act <= x) {
            return here;
        }
        const node& left = nodes[2 * v];
        const int mid = tl + (tr - tl) / 2;
        if (x <= left.act) {
            return get(2 * v, tl, mid, x);
        }
        return unite(left, get(2 * v + 1, mid + 1, tr, x - left.act));
    }

    // Sum of the first x active leaves; non-positive x yields an empty summary.
    node get(int x) {
        return x > 0 ? get(1, 0, sze - 1, x) : node(0, 0);
    }
};
// Shared state for findMaxAttraction and its helpers findans/div.
const int maxn = 100010;    // capacity of the city-indexed arrays (1e5 + 10)
vi inds(maxn);              // city indices ordered by attraction, largest first
vi getind(maxn);            // getind[city] = position of `city` within `inds`
segTree dat;                // switch-on/off tree over the sorted attractions
int strt;                   // start city in the current array orientation
vector<pl> ans1(3 * maxn);  // table saved from the first pass (swapped out of ans2)
vector<pl> ans2(3 * maxn);  // (value, endpoint) per day count, filled by div
// Scores every endpoint city in [tl, tr] for a trip from `strt` with `d` days:
// reaching endpoint i costs mult * (i - strt) days, and the remaining days are
// spent on the largest currently active attractions (dat.get). Returns
// {best score, endpoint}; on ties the smallest endpoint wins, and an empty
// range yields {-1, -1}. Cities tl..tr are switched on first and are all
// switched off again by the end. Other cities keep whatever state the caller
// left them in, so cities before tl that should count must already be on and
// cities after tr must be off.
pl findans(int d, int tl, int tr, int mult) {
    for (int city = tl; city <= tr; ++city) {
        dat.chstate(getind[city], 1);
    }
    pl best(-1, -1);
    // Sweep endpoints right-to-left and switch each one off after scoring it,
    // so the next (smaller) endpoint only sees cities up to itself in [tl, tr].
    for (int city = tr; city >= tl; --city) {
        const ll score = dat.get(d - mult * (city - strt)).val;
        dat.chstate(getind[city], 0);
        if (score >= best.fi) {
            best = pl(score, city);
        }
    }
    return best;
}
// Divide-and-conquer over day counts [l, r]: stores findans(k, ...) in ans2[k]
// for each k, searching endpoints only in [tl, tr]. The middle day count's
// best endpoint bounds the search for the lower half (<= opt) and the upper
// half (>= opt). This relies on the best endpoint being non-decreasing in the
// day count, which is not checked here. Before the upper half runs, cities
// tl..opt are switched back on so the sweep there counts them.
void div(int l, int r, int tl, int tr, int mult) {
    if (l > r) {
        return;
    }
    const int mid = l + (r - l) / 2;
    ans2[mid] = findans(mid, tl, tr, mult);
    if (l == r) {
        return;
    }
    // The recursive call on [l, mid-1] never writes ans2[mid], so caching is safe.
    const int opt = ans2[mid].se;
    div(l, mid - 1, tl, opt, mult);
    for (int city = tl; city <= opt; ++city) {
        dat.chstate(getind[city], 1);
    }
    div(mid + 1, r, opt, tr, mult);
}
// Set while findMaxAttraction may still recurse into its mirrored
// (reversed-array) instance; it is cleared just before that recursive call
// so the mirror does not recurse again.
bool other = true;
// Entry point: the best total attraction collectable in `d` days when starting
// at city `_strt` of the `n` cities whose attractions are att[0..n-1]. Moving
// to a neighbouring city costs one day and each visit costs one day (see
// findans). Pass 1 values a round trip to the right of the start (2 days per
// step). Pass 2, run on the reversed array, values a one-way trip that starts
// at the start's left neighbour (1 day per step). The two are combined over
// every split of the days. A single mirrored recursive call repeats all of
// this with left and right exchanged.
// `att` is reversed in place while working. The outermost call restores it
// and resets the global `other` flag before returning, so the caller's array
// is unchanged and the function can be called again.
ll findMaxAttraction(int n, int _strt, int d, int att[]) {
    // No cities or no days: nothing can be collected. This also avoids an empty
    // segment tree and negative table indices.
    if (n <= 0 || d <= 0) {
        return 0;
    }
    // `other` is only set on the outermost call; the mirrored call sees it cleared.
    const bool outermost = other;
    vi saved;
    if (outermost) {
        saved.assign(att, att + n);
    }
    // Undo caller-visible side effects when the outermost call finishes.
    auto finish = [&](ll res) {
        if (outermost) {
            copy(all(saved), att);
            other = 1;
        }
        return res;
    };
    // Visiting every city takes at most n visits plus fewer than 2n steps, so
    // more than 3n days cannot help. Capping keeps ans1/ans2 indices in range.
    d = min(d, 3 * n);
    strt = _strt;
    iota(inds.begin(), inds.begin() + n, 0);
    sort(inds.begin(), inds.begin() + n, [&](int a, int b) { return att[a] > att[b]; });
    for (int i = 0; i < n; i++) {
        getind[inds[i]] = i;  // city -> rank by attraction
    }
    vi val(att, att + n);
    sort(all(val), [](int a, int b) { return a > b; });  // attractions, descending
    dat.init(n, val);
    // Cities left of the start are unreachable on a rightward trip.
    for (int i = 0; i < strt; i++) {
        dat.chstate(getind[i], 0);
    }
    // Pass 1: ans2[k] = best with k days going right from the start and back.
    div(0, d, strt, n - 1, 2);
    if (strt == 0) {
        // Nothing lies to the left. The mirrored call (start at n-1) covers the
        // one-way trip to the right.
        ll ans = ans2[d].fi;
        if (other) {
            other = 0;
            reverse(att, att + n);
            strt = n - 1 - strt;
            ans = max(ans, findMaxAttraction(n, strt, d, att));
        }
        return finish(ans);
    }
    // Keep pass 1's table as ans1; pass 2 refills ans2.
    swap(ans1, ans2);
    // Pass 2 runs on the reversed array, starting at the original start's left
    // neighbour (reversed index n - strt) with 1 day per step.
    reverse(att, att + n);
    sort(inds.begin(), inds.begin() + n, [&](int a, int b) { return att[a] > att[b]; });
    strt = n - 1 - strt + 1;
    for (int i = 0; i < n; i++) {
        getind[inds[i]] = i;
    }
    dat.init(n, val);
    for (int i = 0; i < strt; i++) {
        dat.chstate(getind[i], 0);
    }
    div(0, d, strt, n - 1, 1);
    strt = n - strt;
    // Split the days: i for the round trip, one step to the left neighbour, and
    // the rest for the one-way trip (0 when nothing is left).
    ll ans = 0;
    for (int i = 0; i <= d; i++) {
        ll t = ans1[i].fi + ans2[max(d - 1 - i, 0)].fi;
        ans = max(ans, t);
    }
    if (other) {
        other = 0;
        strt = n - 1 - strt;
        ans = max(ans, findMaxAttraction(n, strt, d, att));
    }
    return finish(ans);
}
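/*
 * Residue of the online judge's submission-results table, captured before the
 * results had loaded. It is kept verbatim as a comment so the file still
 * compiles.
 *
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 */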