Submission #494692

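| # | Time | Username | Problem | Language | Result | Execution time | Memory |
|---|------|----------|---------|----------|--------|----------------|--------|
| 494692 | | wiwiho | Olympiads (BOI19_olympiads) | C++14 | 100 / 100 | 43 ms | 2904 KiB |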
#include <bits/stdc++.h>
#include <bits/extc++.h>

// Fast stream I/O: detach iostreams from C stdio and untie cin/cout.
// The expansion already ends in ';', so call sites write it without one.
#define StarBurstStream ios_base::sync_with_stdio(false); cin.tie(0); cout.tie(0);
// Whole-container iterator ranges (forward / reverse) for algorithm calls.
#define iter(a) a.begin(), a.end()
#define riter(a) a.rbegin(), a.rend()
// lsort: ascending sort. gsort: descending sort (ascending over reverse iterators).
#define lsort(a) sort(iter(a))
#define gsort(a) sort(riter(a))
// Sequence-container push / emplace / pop shorthands (single-argument forms).
#define pb(a) push_back(a)
#define eb(a) emplace_back(a)
#define pf(a) push_front(a)
#define ef(a) emplace_front(a)
#define pob pop_back()
#define pof pop_front()
// Pair / tuple shorthands.
#define mp(a, b) make_pair(a, b)
#define F first
#define S second
#define mt make_tuple
#define gt(t, i) get<i>(t)
// In-place update: a = max(a, b) / a = min(a, b).
#define tomax(a, b) ((a) = max((a), (b)))
#define tomin(a, b) ((a) = min((a), (b)))
// Normalise a into [0, MOD), also for negative a. MOD is defined further down.
#define topos(a) ((a) = (((a) % MOD + MOD) % MOD))
// Shrink a to drop adjacent duplicates; only removes all duplicates if a is sorted.
#define uni(a) a.resize(unique(iter(a)) - a.begin())
// Print the elements of a to stream b, separated by single spaces, then "\n".
// Note: `pvaspace=true;` is NOT inside the preceding if. It runs for every
// element, so only the first element is written without a leading space.
#define printv(a, b) {bool pvaspace=false; \
for(auto pva : a){ \
    if(pvaspace) b << " "; pvaspace=true;\
    b << pva;\
}\
b << "\n";}

using namespace std;
using namespace __gnu_pbds;

// Scalar shorthands used throughout the solution.
using ll = long long;
using ull = unsigned long long;
using ld = long double;

// Small aggregate shorthands.
using pii = pair<int, int>;
using pll = pair<ll, ll>;
using pdd = pair<ld, ld>;
using tiii = tuple<int, int, int>;

// Modulus used by the topos normalisation macro.
const ll MOD = 1'000'000'007;
// Largest value representable by a 32-bit signed int.
const ll MAX = 2'147'483'647;

// Streams a pair as "(first,second)" with no spaces; used for debug output.
template<typename First, typename Second>
ostream& operator<<(ostream& out, pair<First, Second> pr){
    out << '(' << pr.first << ',' << pr.second << ')';
    return out;
}

// floor(a / b) for signed integers (b != 0): rounds toward negative infinity
// rather than C++'s truncation toward zero.
ll ifloor(ll a, ll b){
    if(b < 0){
        // Move the sign onto the numerator so the divisor is positive.
        a = -a;
        b = -b;
    }
    return a >= 0 ? a / b : (a - b + 1) / b;
}

// ceil(a / b) for signed integers (b != 0): rounds toward positive infinity
// rather than C++'s truncation toward zero.
ll iceil(ll a, ll b){
    if(b < 0){
        // Move the sign onto the numerator so the divisor is positive.
        a = -a;
        b = -b;
    }
    return a <= 0 ? a / b : (a + b - 1) / b;
}

// Coordinate compressor that ranks the distinct input values in DESCENDING
// order: rank 0 is the largest value.
struct discr{
    vector<int> t;   // distinct values, largest first
    int sz;          // number of distinct values
    discr(vector<int>& v): t(v){
        sort(t.begin(), t.end(), greater<int>());
        t.erase(unique(t.begin(), t.end()), t.end());
        sz = static_cast<int>(t.size());
    }
    // Rank of the first stored value that is <= x; returns sz when none is.
    int operator()(int x){
        auto pos = lower_bound(t.begin(), t.end(), x, greater<>());
        return static_cast<int>(pos - t.begin());
    }
    // Value at rank x (no bounds check).
    int operator[](int x){
        return t[x];
    }
};

// Input sizes: n contestants, teams of k, stop at the c-th distinct team.
int n;
int k;
int c;
// d[j]: descending rank compression of the scores in event j.
vector<discr> d;
// Contestants whose every score is <= the current threshold vector.
vector<int> ok;
// ex[j]: contestants in `ok` whose event-j score equals the event-j threshold.
vector<vector<int>> ex;
// a[i][j]: score of contestant i (1-indexed) in event j.
vector<vector<int>> a;
// Number of distinct teams counted so far.
int cnt = 0;
// Per-event running maximum of the team currently being built.
vector<int> mx;
// use[i]: contestant i is already part of the team being built.
vector<bool> use;

// Sizes the global work buffers once n and k have been read.
// Contestants are 1-indexed, hence the n + 1 rows / flags.
void init(){
    ex.resize(k);
    use.resize(n + 1);
    mx.resize(k);
    a.resize(n + 1, vector<int>(k));
}

// Threshold vector for the best-first search. v[j] is a rank into d[j]
// (rank 0 = largest distinct score of event j). sum is the total of the
// corresponding scores d[j][v[j]] over the k events.
// operator< orders larger sums first and breaks ties lexicographically on v,
// so the begin() of a std::set<info> is the highest-sum entry.
// The converting constructor is deliberately implicit (`info t = v;` is used).
struct info{
    vector<int> v;
    int sum = 0;
    info(vector<int> ranks): v(ranks){
        for(int j = 0; j < k; ++j){
            sum += d[j][ranks[j]];
        }
    }
    bool operator<(const info& b) const{
        if(sum != b.sum) return sum > b.sum;
        return v < b.v;
    }
};

// Contestant ids of the team currently being built, in pick order.
vector<int> arr;
// Sorted member lists of every team already counted. A team reached through a
// different pick order is therefore counted only once.
set<vector<int>> owo;
// Builds a team one member per event slot `now`, using the threshold ranks v.
// If the team's running maximum in event `now` already equals d[now][v[now]],
// any unused contestant from `ok` may fill the slot. Otherwise the slot must
// come from ex[now], i.e. someone who hits that threshold exactly.
// Each completed team not seen before increments cnt. When cnt reaches c,
// the team score (sum of mx) is printed and the program exits.
void dfs(int now, vector<int>& v){
    if(now == k){
        vector<int> team(arr);
        sort(team.begin(), team.end());
        if(!owo.insert(team).second) return;  // already counted this team
        ++cnt;
        if(cnt != c) return;
        cout << accumulate(mx.begin(), mx.begin() + k, 0) << "\n";
        exit(0);
    }

    // ok and ex are not modified during the recursion, so a reference is stable.
    const vector<int>& pool = (mx[now] == d[now][v[now]]) ? ok : ex[now];
    for(int who : pool){
        if(use[who]) continue;
        const vector<int> saved = mx;
        for(int t = 0; t < k; t++) mx[t] = max(mx[t], a[who][t]);
        use[who] = true;
        arr.push_back(who);
        dfs(now + 1, v);
        arr.pop_back();
        use[who] = false;
        mx = saved;
    }
}

void bf(int now, vector<int>& tmp, set<info>& pq){
    if(now == k){
        pq.insert(info(tmp));
        return;
    }
    for(int i = 0; i < d[now].sz; i++){
        tmp.eb(i);
        bf(now + 1, tmp, pq);
        tmp.pob;
    }
}

//mt19937 rnd(124124);
//uniform_int_distribution<int> ud(1, 10);

int main(){
    StarBurstStream
    //freopen("test.txt", "w", stderr);

    cin >> n >> k >> c;
    init();

    vector<vector<int>> td(k);
    for(int i = 1; i <= n; i++){
        for(int j = 0; j < k; j++){
            cin >> a[i][j];
            //a[i][j] = ud(rnd);
            td[j].eb(a[i][j]);
        }
    }
    for(int i = 0; i < k; i++){
        d.eb(td[i]);
    }
    //cerr << "ok\n";

    /*cerr << "discr\n";
    for(int i = 0; i < k; i++){
        cerr << i << "  ";
        printv(d[i].t, cerr);
    }*/

    set<info> pq;
    set<info> vst;
    pq.insert(info(vector<int>(k)));
    vst.insert(info(vector<int>(k)));

    /*set<info> all;
    vector<int> qq;
    bf(0, qq, all);*/
    
    while(!pq.empty()){
        info vi = *pq.begin();
        pq.erase(pq.begin());
        vector<int> v = vi.v;
        /*if(v != all.begin()->v){
            cerr << "OAO " << all.begin()->sum << "\n";
            printv(v, cerr);
            printv(all.begin()->v, cerr);

            for(auto& i : pq){
                cerr << " " << i.sum << " ";
                printv(i.v, cerr);
            }
            
            assert(false);
        }
        all.erase(all.begin());*/
        
        for(int i = 0; i < k; i++) ex[i].clear();
        ok.clear();
        for(int i = 1; i <= n; i++){
            for(int j = 0; j < k; j++){
                if(a[i][j] > d[j][v[j]]) goto nxt;
            }
            ok.eb(i);
            for(int j = 0; j < k; j++){
                if(a[i][j] == d[j][v[j]]) ex[j].eb(i);
            }
            nxt:;
        }
        bool oao = false;
        for(int i = 0; i < k; i++) if(ex[i].empty()) oao = true;
        if(ok.size() < k){
            //cerr << "reject ";
            //printv(v, cerr);
            continue;
        }

        if(!oao){
            //cerr << "solve ";
            //printv(v, cerr);
            //printv(ok, cerr);

            fill(iter(mx), -1);
            dfs(0, v);
        }

        for(int i = 0; i < k; i++){
            if(v[i] + 1 >= d[i].sz) continue;
            v[i]++;
            info t = v;
            if(vst.find(t) == vst.end()){
                vst.insert(t);
                pq.insert(t);
            }
            //cerr << "add ";
            //printv(v, cerr);
            v[i]--;
        }
    }
    //assert(false);

    return 0;
}

Compilation message (stderr)

olympiads.cpp: In function 'int main()':
olympiads.cpp:239:22: warning: comparison of integer expressions of different signedness: 'std::vector<int>::size_type' {aka 'long unsigned int'} and 'int' [-Wsign-compare]
  239 |         if(ok.size() < k){
      |            ~~~~~~~~~~^~~
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...