Submission #1261759

# | Time | Username | Problem | Language | Result | Execution time | Memory
1261759 |  | medmdg | Festival (IOI25_festival) | C++20 | 100 / 100 | 88 ms | 23720 KiB
#include "festival.h"
#include<bits/stdc++.h>
#define mk(a,b) make_pair(a,b)
#define pb(a) push_back(a)
using namespace std;
using ll = long long;

// Shared state for one max_coupons() call; max_coupons resets it on entry.
ll n;                        // number of coupons (P.size())
vector<pair<ll,int>> t[4];   // t[g]: (price, coupon index) of coupons with type g+1, sorted ascending
set<pair<ll,ll>> s;          // (prefix sum of t[0] prices, length of that prefix), starting at (0,0)
ll a;                        // current token count
vector<ll> p;                // p[i]: price of coupon i
constexpr ll sz = 31;        // sub6 DP only explores states with i+j+k < sz
ll M = 1000000000000000LL;   // 1e15: cap on token counts so the multiplications stay in range
// Returns the largest c such that the c cheapest type-1 coupons (t[0]) cost
// at most `a` in total, or 0 when a < 0. Relies on the global set `s` of
// (prefix sum, prefix length) pairs that max_coupons builds from t[0].
int getP1(ll a){
    if (a < 0) return 0;
    // First entry whose prefix sum strictly exceeds `a`.
    auto past = s.upper_bound(make_pair(a, LLONG_MAX));
    return past == s.begin() ? 0 : prev(past)->second;
}
// Alternative solver that uses only type-2 coupons (t[1]) followed by
// type-1 coupons (t[0]). For each c, it buys the c cheapest type-2 coupons.
// Tokens are doubled after each purchase, but only while below 1e17, to avoid
// overflow. It then adds as many cheap type-1 coupons as getP1 allows, and keeps
// the c that maximises the total count.
// Returns coupon indices in purchase order: the type-2 prefix, then the type-1 prefix.
// NOTE(review): max_coupons dispatches to sub6, not here, so this looks like a
// leftover subtask solution. It also modifies the global token count `a`.
vector<int> sub3(){
    ll best = getP1(a);   // score when no type-2 coupon is bought
    ll taken = 0;         // how many type-2 coupons the best choice buys
    for (size_t idx = 0; idx < t[1].size(); ++idx) {
        a -= t[1][idx].first;
        if (a < 1e17)
            a *= 2;
        if (a < 0) break;                 // could not afford this coupon
        const ll score = getP1(a) + static_cast<ll>(idx) + 1;
        if (score > best) {
            best = score;
            taken = static_cast<ll>(idx) + 1;
        }
    }
    vector<int> order;
    for (ll q = 0; q < taken; ++q)
        order.push_back(t[1][q].second);
    for (ll q = 0; q < best - taken; ++q)
        order.push_back(t[0][q].second);
    return order;
}

// Chooses which multiplier group to buy from next.
// v1, v2, v3 are the prices of the cheapest unused coupon of type 2, 3 and 4
// (multipliers 2, 3, 4); -1 marks an empty group.
// Exchange argument: buying x before y never ends with fewer tokens iff
//     px*mx/(mx-1) <= py*my/(my-1),
// i.e. keys 2p, 3p/2, 4p/3. Scaled by 6, these are 12p, 9p, 8p.
// Returns 0/1/2 for the group with the smallest key, breaking ties toward the
// lower group. Returns 2 when no group is available.
// The token count `a` cancels out of every comparison, so it does not affect
// the result; it is kept only for interface compatibility.
ll pick_best(ll a,ll v1,ll v2,ll v3){
    (void)a;
    const ll price[3] = {v1, v2, v3};
    const ll weight[3] = {12, 9, 8};
    ll choice = 2;
    ll bestKey = 0;
    bool seen = false;
    for (int g = 0; g < 3; ++g) {
        if (price[g] == -1) continue;       // group exhausted
        const ll key = price[g] * weight[g];
        if (!seen || key < bestKey) {       // strict '<' keeps the lower group on ties
            seen = true;
            bestKey = key;
            choice = g;
        }
    }
    return choice;
}
vector<int> Ans;
void increasing(){
    ll ta=a;
    vector<int> ans;
    ll i=0,j=0,k=0;
    while(i<t[1].size()||j<t[2].size()||k<t[3].size()){
        ll v1=-1,v2=-1,v3=-1;
        if(i<t[1].size())  v1=t[1][i].first;
        if(j<t[2].size())  v2=t[2][j].first;
        if(k<t[3].size())  v3=t[3][k].first;
        ll v=pick_best(ta,v1,v2,v3);
        if(v==0)ans.pb(t[1][i++].second);
        if(v==1)ans.pb(t[2][j++].second);
        if(v==2)ans.pb(t[3][k++].second);
        if(ta-p[ans.back()]<0){
            ans.pop_back();
            break;
        }
        if(ta>=M){continue;}
        if((ta-p[ans.back()])*(v+2)<ta){
            if(v==0)i--;
            if(v==1)j--;
            if(v==2)k--;
            ans.pop_back();
            break;
        }
        ta-=p[ans.back()];
        ta*=(v+2);
    }
    a=ta;
    Ans=ans;
    for(int z=0;z+i<t[1].size();z++)t[1][z]=t[1][z+i];
    for(int z=0;z+j<t[2].size();z++)t[2][z]=t[2][z+j];
    for(int z=0;z+k<t[3].size();z++)t[3][z]=t[3][z+k];
    t[1].resize(t[1].size()-i);
    t[2].resize(t[2].size()-j);
    t[3].resize(t[3].size()-k);

}

// Full solver, and the one max_coupons dispatches to.
//  1. increasing() greedily buys token-non-decreasing multiplier coupons into Ans.
//     It updates `a` and trims the consumed coupons from t[1..3].
//  2. A DP then explores the remaining multiplier coupons:
//     dp[i][j][k] = most tokens, capped at M, after buying the i, j, k cheapest
//     remaining coupons of t[1], t[2], t[3] (factors 2, 3, 4), or -1 if unreachable.
//     pre[i][j][k] records which group was bought last (0/1/2), or -1 at the
//     origin and at unreachable states. Only states with i+j+k < sz are filled.
//  3. Each reachable state scores i+j+k plus getP1(tokens), i.e. the type-1
//     coupons bought afterwards. The best state is reconstructed and appended to Ans.
// Returns Ans: coupon indices in purchase order.
vector<int> sub6(){
    increasing();
    // Entries with i+j+k >= sz are never written and never read.
    ll dp[sz][sz][sz];
    int pre[sz][sz][sz];
    // Baseline: buy no further multiplier coupons, only type-1 ones.
    ll ma=getP1(a);
    ll ci=0,cj=0,ck=0;
    dp[0][0][0]=a;
    for(int i=0;i<sz && i<=t[1].size();i++){
        for(int j=0;j<sz&& j<=t[2].size();j++){
            for(int k=0;k<sz&& k<=t[3].size();k++){
                pre[i][j][k]=-1;
                if(i+j+k>=sz)break;
                dp[i][j][k]=-1;
                if(i==j && j==k && k==0)dp[i][j][k]=a;
                // Candidate token counts when the last purchase came from each group.
                // An unreachable predecessor (-1) or an unaffordable price gives a
                // value below -1, so that transition is never chosen.
                ll v1=-1,v2=-1,v3=-1;
                if(i){
                    v1=(dp[i-1][j][k]-t[1][i-1].first)*2;
                    v1=min(v1,M);
                }
                if(j){
                    v2=(dp[i][j-1][k]-t[2][j-1].first)*3;
                    v2=min(v2,M);
                }
                if(k){
                    v3=(dp[i][j][k-1]-t[3][k-1].first)*4;
                    v3=min(v3,M);
                }
                // Strict '>' means ties keep the earlier group.
                if(v1>dp[i][j][k]){
                    dp[i][j][k]=v1;
                    pre[i][j][k]=0;
                }
                if(v2>dp[i][j][k]){
                    dp[i][j][k]=v2;
                    pre[i][j][k]=1;
                }
                if(v3>dp[i][j][k]){
                    dp[i][j][k]=v3;
                    pre[i][j][k]=2;
                }
                if(dp[i][j][k]<0)continue;
                // Total = multiplier coupons bought here + type-1 coupons still affordable.
                ll val=i+j+k+getP1(dp[i][j][k]);
                if(val>ma){
                    ci=i,cj=j,ck=k;
                    ma=val;
                    continue;
                }
            }
        }
    }
    ll vrem=getP1(dp[ci][cj][ck]);
    // `to` is built back-to-front: the type-1 coupons (bought last, cheapest
    // first) go in first, then the multiplier coupons are added while walking
    // `pre` back to the origin.
    vector<int> to;

    for(int i=0;i<vrem;i++){
        to.push_back(t[0][i].second);
    }
    while(1){
        if(pre[ci][cj][ck]==-1)break;
        if(pre[ci][cj][ck]==0){
            to.push_back(t[1][--ci].second);
            continue;
        }
        if(pre[ci][cj][ck]==1){
            to.push_back(t[2][--cj].second);
            continue;
        }
        if(pre[ci][cj][ck]==2){
            to.push_back(t[3][--ck].second);
            continue;
        }
    }
    // Reverse onto Ans, after the purchases made by increasing().
    for(int i=to.size()-1;i>=0;i--)Ans.push_back(to[i]);
    return Ans;
}
// Grader entry point. A is the starting token count. P[i] and T[i] are the
// price and type (1..4) of coupon i. Rebuilds all global state for this call,
// then runs sub6. Returns the chosen coupon indices in purchase order.
vector<int> max_coupons(int A, vector<int> P, vector<int> T) {
    // Bucket coupons by type (type g+1 -> t[g]) and sort each bucket by (price, index).
    for (auto& bucket : t) bucket.clear();
    for (size_t idx = 0; idx < T.size(); ++idx)
        t[T[idx] - 1].emplace_back(P[idx], static_cast<int>(idx));
    for (auto& bucket : t) sort(bucket.begin(), bucket.end());

    // Prefix sums of type-1 prices, used by getP1.
    s.clear();
    s.emplace(0LL, 0LL);
    ll running = 0;
    for (size_t c = 0; c < t[0].size(); ++c) {
        running += t[0][c].first;
        s.emplace(running, static_cast<ll>(c + 1));
    }

    n = static_cast<ll>(P.size());
    a = A;
    p.assign(P.begin(), P.end());
    Ans.clear();
    return sub6();
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...