Submission #1078082

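| # | Time | Username | Problem | Language | Result | Execution time | Memory |
|---|------|----------|---------|----------|--------|----------------|--------|
| 1078082 | | anango | Carnival Tickets (IOI20_tickets) | C++17 | 100 / 100 | 1175 ms | 190640 KiB |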
#include "tickets.h"
#include <vector>
#include <bits/stdc++.h>
using namespace std;
#define int long long

// Converts a grid of (file-wide `#define int long long`) 64-bit values into
// the plain-`int` grid type that allocate_tickets() takes.
//
// Each output row is built from the corresponding input row's own length, so
// an empty grid or ragged rows are handled without reading res[0].
// Values are narrowed by implicit conversion; callers must keep them within
// `signed` range (find_maximum only stores -1 or a round index < k).
//
// @param res  grid to convert; taken by const reference to avoid a full copy.
// @return     element-wise `signed` copy of res.
std::vector<std::vector<signed>> intify(const std::vector<std::vector<int>>& res) {
    std::vector<std::vector<signed>> ans;
    ans.reserve(res.size());
    for (const auto& row : res) {
        ans.emplace_back(row.begin(), row.end());
    }
    return ans;
}

long long find_maximum(signed k, std::vector<std::vector<signed>> x) {
    //need to fill the grid with + and -
    //exactly k of these in each row
    //such that the total sum of + minus total sum of - is maximised
    //so if we let the possible moves for each row be a set S, each element is {a,b}
    //a is the increase in total sum of + minus total sum of -
    //and b is the total change to the balance (number of + minus number of -)
    //then effectively in each row we need to process these moves in a dp
    //like if we let dp[i][balance] be the maximum answer after doing the first i rows with balance + minus - count
    //the balance can be upto nk/2 = O(n^2), so O(n^3) states and thusforely O(n^4) time complexity
    //passes for 63
    //try optimise to O(n^3)
    //randomise order of rows processed and hope the balance value doesn't become large in the optimal sol?
    //actually quite reasonable
    //if we consider worst case where everything is extreme so it's +-k each time
    //then by random walk theory the EV of |balance| is like k*sqrt(n)
    //so it goes from n^4 to n^3.5... probably not enough to pass

    //what if instead we bruteforce over the split point where stuff goes from minus to plus
    //doesn't help
    //what about starting by taking all plus (k pluses per row, take largest)
    //and slowly converting to minus, losing the minimal each time
    //not sure how to prove it works but whatever just implement
    //probably works because the costs in one row are increasing

    

	int n = x.size();
	int m = x[0].size();
    vector<vector<int>> answer(n,vector<int>(m,-1));
    vector<vector<int>> increasing_order(n,vector<int>(m)); 
    //the indices sorted in increasing order of value
    for (int i=0; i<n; i++) {
        iota(increasing_order[i].begin(), increasing_order[i].end(), (int)0);
        //sort(increasing_order[i].begin(), increasing_order[i].end(), [&](const int i1, const int i2) {
        //    return x[i][i1]<x[i][i2];
        //});
    }
    vector<int> opsdone(n,0); //num of ops done in the ith row so far
    int cursum = 0;
    priority_queue<pair<int,int>> deltas; //with cost of -a, we can apply one greedy operation to row i
    for (int i=0; i<n; i++) {
        for (int j=m-k; j<m; j++) {
            //cout << i <<" " << j << endl;
            cursum+=x[i][increasing_order[i][j]];
            answer[i][j] = -2;
        }
        pair<int,int> p = {-x[i][m-k]-x[i][0],i};
        //cout << "start " << i <<" " << p.first <<" " << p.second << endl;
        deltas.push(p);
    }
    int ctreq = n*k/2;
    while (ctreq>0) {
        assert(deltas.size());
        pair<int,int> p = deltas.top(); deltas.pop(); //smallest cost
        //cout << "doing1 " << p.first <<" " << p.second << endl;
        int cost = -p.first;
        int ind = p.second;
        cursum-=cost;
        answer[ind][increasing_order[ind][m-k+opsdone[ind]]] = -1;
        answer[ind][increasing_order[ind][opsdone[ind]]] = -3;
        opsdone[ind]++;
        //now we need to remove index n-k+1+opsdone[ind] and re-add index opsdone[ind]
        if (opsdone[ind]<k) {
            //cout << "costing " << -x[ind][increasing_order[ind][m-k+opsdone[ind]]]-x[ind][increasing_order[ind][opsdone[ind]]] <<" " << opsdone[ind] << endl;
            deltas.push({-x[ind][increasing_order[ind][m-k+opsdone[ind]]]-x[ind][increasing_order[ind][opsdone[ind]]],ind});
        }
        ctreq--;
        //cout << "doing2 " << ctreq <<" " << cursum << endl;
    }
    int m2ct = 0;
    int m3ct = 0;
    for (int i=0; i<n; i++) {
        for (int j=0; j<m; j++) {
            //cout << answer[i][j] <<" ";
            m2ct+=answer[i][j]==-2;
            m3ct+=answer[i][j]==-3;
        }
        //cout << endl;
    }
    assert(m2ct==n*k/2 && m3ct==n*k/2);
    

	/*vector<vector<int>> elems;
    for (int i=0; i<n; i++) {
        for (int j=0; j<m; j++) {
            elems.push_back({x[i][j],i,j});
        }
    }
    sort(elems.begin(), elems.end());
    int mid = n*m/2;
    for (int i=0; i<mid; i++) {
        answer[elems[i][1]][elems[i][2]] = -2; //minus
    }*/
    int sol = 0;
    for (int i=0; i<n; i++) {
        for (int j=0; j<m; j++) {
            if (answer[i][j]==-2) {
                sol+=x[i][j];
            }
            else if (answer[i][j]==-3) {
                sol-=x[i][j];
            }
        }
    }
    //-2 is plus, -3 is minus
    vector<set<int>> minus_rem(n); vector<set<int>> plus_rem(n);
    for (int i=0; i<n; i++) {
        for (int j=0; j<m; j++) {
            if (answer[i][j]==-2) {
                plus_rem[i].insert(j);
            }
            else if (answer[i][j]==-3) {
                minus_rem[i].insert(j);
            }
        }
    }
    for (int op=0; op<k; op++) {
        int balance = 0;
        set<int> rem;
        for (int i=0; i<n; i++) {
            if (!plus_rem[i].size()) {
                answer[i][*minus_rem[i].begin()] = op;
                minus_rem[i].erase(minus_rem[i].begin());
                balance--;
            }
            else if (!minus_rem[i].size()) {
                answer[i][*plus_rem[i].begin()] = op;
                plus_rem[i].erase(plus_rem[i].begin());
                balance++;
            }
            else {
                rem.insert(i);
            }
        }
        for (int i:rem) {
            if (balance>=0) {
                answer[i][*minus_rem[i].begin()] = op;
                minus_rem[i].erase(minus_rem[i].begin());
                balance--;
            }
            else {
                answer[i][*plus_rem[i].begin()] = op;
                plus_rem[i].erase(plus_rem[i].begin());
                balance++;
            }
        }
    }

	allocate_tickets(intify(answer));
	return sol;
}

Compilation message (stderr)

tickets.cpp: In function 'std::vector<std::vector<int> > intify(std::vector<std::vector<long long int> >)':
tickets.cpp:9:20: warning: comparison of integer expressions of different signedness: 'long long int' and 'std::vector<std::vector<long long int> >::size_type' {aka 'long unsigned int'} [-Wsign-compare]
    9 |     for (int i=0; i<res.size(); i++) {
      |                   ~^~~~~~~~~~~
tickets.cpp:10:24: warning: comparison of integer expressions of different signedness: 'long long int' and 'std::vector<long long int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
   10 |         for (int j=0; j<res[0].size(); j++) {
      |                       ~^~~~~~~~~~~~~~
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...