Submission #1090022

# | Time | Username | Problem | Language | Result | Execution time | Memory
1090022 | — | onlk97 | Carnival Tickets (IOI20_tickets) | C++14 | 67 / 100 | 556 ms | 114588 KiB
#include "tickets.h"
#include <vector>
#include <bits/stdc++.h>
using namespace std;
long long dp[300][10000],bac[300][10000],ps[100][100],bps[100][100];
long long find_maximum(int k,vector <vector <int> > x){
    int n=x.size(),m=x[0].size();
    vector <vector <int> > op(n,vector <int>(m,-1));
    int chkmx=0;
    for (int i=0; i<n; i++) for (int j=0; j<m; j++) chkmx=max(chkmx,x[i][j]);
    long long ans=0;
    if (k==1){
        int idxmn[n],idxmx[n];
        for (int i=0; i<n; i++){
            idxmn[i]=min_element(x[i].begin(),x[i].end())-x[i].begin();
            idxmx[i]=max_element(x[i].begin(),x[i].end())-x[i].begin();
        }
        long long dp[n+1][n+1],bac[n+1][n+1];
        for (int i=0; i<=n; i++){
            for (int j=0; j<=n; j++) dp[i][j]=-1e18;
        }
        dp[0][0]=0;
        for (int i=0; i<n; i++){
            for (int j=0; j<=n/2; j++){
                if (dp[i][j]-x[i][idxmn[i]]>=dp[i+1][j]){
                    dp[i+1][j]=dp[i][j]-x[i][idxmn[i]];
                    bac[i+1][j]=0;
                }
                if (dp[i][j]+x[i][idxmx[i]]>=dp[i+1][j+1]){
                    dp[i+1][j+1]=dp[i][j]+x[i][idxmx[i]];
                    bac[i+1][j+1]=1;
                }
            }
        }
        ans=dp[n][n/2];
        int cur=n/2;
        for (int i=n; i>0; i--){
            if (bac[i][cur]){
                op[i-1][idxmx[i-1]]=0;
                cur--;
            } else {
                op[i-1][idxmn[i-1]]=0;
            }
        }
    } else if (chkmx<=1){
        int bal[n];
        for (int i=0; i<n; i++){
            bal[i]=0;
            for (int j=0; j<m; j++){
                if (!x[i][j]) bal[i]--;
                else bal[i]++;
            }
        }
        int fr[n],ba[n];
        for (int i=0; i<n; i++) fr[i]=0;
        for (int i=0; i<n; i++) ba[i]=m-1;
        for (int i=0; i<k; i++){
            vector <pair <int,int> > v;
            for (int j=0; j<n; j++) v.push_back({bal[j],j});
            sort(v.begin(),v.end());
            for (int j=0; j<n; j++){
                if (j<n/2){
                    op[v[j].second][fr[v[j].second]]=i;
                    if (!x[v[j].second][fr[v[j].second]]) bal[v[j].second]++;
                    else bal[v[j].second]--;
                    ans-=x[v[j].second][fr[v[j].second]];
                    fr[v[j].second]++;
                } else {
                    op[v[j].second][ba[v[j].second]]=i;
                    if (x[v[j].second][ba[v[j].second]]) bal[v[j].second]--;
                    else bal[v[j].second]++;
                    ans+=x[v[j].second][ba[v[j].second]];
                    ba[v[j].second]--;
                }
            }
        }
    } else if (k==m){
        vector <pair <long long,pair <int,int> > > v;
        for (int i=0; i<n; i++) for (int j=0; j<m; j++) v.push_back({x[i][j],{i,j}});
        sort(v.begin(),v.end());
        for (int i=0; i<n*m; i++){
            if (i<n*m/2) ans-=v[i].first;
            else ans+=v[i].first;
        }
        bool is[n][m];
        for (int i=0; i<n; i++) for (int j=0; j<m; j++) is[i][j]=0;
        for (int i=n*m/2; i<n*m; i++) is[v[i].second.first][v[i].second.second]=1;
        int idx=0;
        for (int i=0; i<n; i++){
            for (int j=0; j<m; j++){
                if (is[i][j]){
                    op[i][j]=idx;
                    idx++;
                    if (idx>=m) idx-=m;
                }
            }
        }
        bool hv[m];
        for (int i=0; i<n; i++){
            for (int j=0; j<m; j++) hv[j]=0;
            for (int j=0; j<m; j++) if (op[i][j]>=0) hv[op[i][j]]=1;
            int ptr=0;
            for (int j=0; j<m; j++){
                if (op[i][j]==-1){
                    while (hv[ptr]) ptr++;
                    op[i][j]=ptr;
                    ptr++;
                }
            }
        }
    } else {
        for (int i=0; i<=n; i++){
            for (int j=0; j<=n*k/2; j++) dp[i][j]=-1e18;
        }
        dp[0][0]=0;
        vector <pair <int,int> > sorted_row[n];
        for (int i=0; i<n; i++){
            for (int j=0; j<m; j++) sorted_row[i].push_back({x[i][j],j});
            sort(sorted_row[i].begin(),sorted_row[i].end());
        }
        for (int i=0; i<n; i++){
            ps[i][0]=0;
            for (int j=1; j<=m; j++) ps[i][j]=ps[i][j-1]+sorted_row[i][j-1].first;
            bps[i][0]=0;
            for (int j=1; j<=m; j++) bps[i][j]=bps[i][j-1]+sorted_row[i][m-j].first;
        }
        for (int i=0; i<n; i++){
            for (int j=0; j<=n*k/2; j++){
                for (int l=0; l<=k; l++){
                    long long nval=dp[i][j]-ps[i][k-l]+bps[i][l];
                    if (nval>=dp[i+1][j+l]){
                        dp[i+1][j+l]=nval;
                        bac[i+1][j+l]=l;
                    }
                }
            }
        }
        ans=dp[n][n*k/2];
        bool is_pos[n][m],is_neg[n][m];
        for (int i=0; i<n; i++){
            for (int j=0; j<m; j++){
                is_pos[i][j]=0; is_neg[i][j]=0;
            }
        }
        int cur=n*k/2;
        for (int i=n; i>0; i--){
            int l=bac[i][cur];
            for (int j=0; j<l; j++) is_pos[i-1][sorted_row[i-1][m-1-j].second]=1;
            for (int j=0; j<k-l; j++) is_neg[i-1][sorted_row[i-1][j].second]=1;
            //cout<<"done pos "<<i<<": "; for (int j=0; j<l; j++) cout<<sorted_row[i-1][m-1-j].second<<' '; cout<<'\n';
            //cout<<"done neg "<<i<<": "; for (int j=0; j<k-l; j++) cout<<sorted_row[i-1][j].second<<' '; cout<<'\n';
            cur-=l;
        }
        int idx=0;
        bool hv[k];
        for (int i=0; i<n; i++){
            for (int j=0; j<m; j++){
                if (is_pos[i][j]){
                    op[i][j]=idx;
                    idx++;
                    if (idx>=k) idx-=k;
                }
            }
            for (int j=0; j<k; j++) hv[j]=0;
            for (int j=0; j<m; j++) if (op[i][j]>=0) hv[op[i][j]]=1;
            int ptr=0;
            for (int j=0; j<m; j++){
                //cout<<i<<' '<<j<<": "<<is_neg[i][j]<<'\n';
                if (is_neg[i][j]){
                    while (hv[ptr]) ptr++;
                    //cout<<"add "<<i<<' '<<j<<": "<<ptr<<'\n';
                    op[i][j]=ptr;
                    ptr++;
                }
            }
        }
    }
    allocate_tickets(op);
    return ans;
}
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...