답안 #348414

# 제출 시각 아이디 문제 언어 결과 실행 시간 메모리
348414 2021-01-15T01:43:48 Z 12tqian Olympiads (BOI19_olympiads) C++17
100 / 100
109 ms 2088 KB
#include <bits/stdc++.h>

using namespace std;

// Loop helpers: f1r iterates i over [a, b), f0r iterates i over [0, a),
// trav walks a container by mutable reference.
// Arguments are parenthesized so that expression arguments expand safely.
#define f1r(i, a, b) for (int i = (a); i < (b); i++)
#define f0r(i, a) f1r(i, 0, a)
#define trav(t, a) for (auto& t : (a))

// Short spellings for common container / pair operations.
#define pb push_back
#define eb emplace_back
#define f first
#define s second
#define mp make_pair
// Container size as int; fully parenthesized so the cast binds to the whole call.
#define sz(x) ((int) (x).size())
// Expands to the begin/end iterator pair of x. No trailing line continuation:
// a backslash here would splice the following source line into the macro.
#define all(x) (x).begin(), (x).end()

// Short type aliases used throughout the solution.
using ll = long long;
using vi = std::vector<int>;
using pi = std::pair<int, int>;
using vpi = std::vector<pi>;
using vl = std::vector<ll>;

// Dimensions of the binomial table C: N rows (first index) by K columns (second index).
constexpr int N = 505;
constexpr int K = 8;

// Binomial table; zero-initialized here and filled at the start of main with choose(i, j).
ll C[N][K];
/**
 * Binomial coefficient "x choose y".
 *
 * Returns 0 when x < y. Otherwise returns the number of ways to pick y items
 * out of x; when y <= 0 the loop does not run and the result is 1.
 *
 * The product and the division are interleaved so that the running value is
 * always an exact binomial coefficient. The largest intermediate is therefore
 * about min(y, x - y) times the final result, instead of the full falling
 * factorial x * (x-1) * ... * (x-y+1), which overflows 64 bits for moderate
 * inputs.
 *
 * (long long is the type the file abbreviates as ll.)
 */
long long choose(long long x, long long y) {
    if (x < y) return 0;
    // C(x, y) == C(x, x - y): use the smaller index to shorten the loop and
    // keep intermediates small. Never triggers for y <= 0.
    if (y > x - y) y = x - y;
    long long res = 1;
    for (long long j = 0; j < y; j++) {
        // Invariant: res == C(x, j). Then res * (x - j) == C(x, j + 1) * (j + 1),
        // so the division below is exact.
        res = res * (x - j) / (j + 1);
    }
    return res;
}
int main() {
#ifndef LOCAL
    cin.tie(0)->sync_with_stdio(0);
#else
    freopen("file.in", "r", stdin);
    freopen("A.out", "w", stdout);
#endif
    f0r(i, N) f0r(j, K) C[i][j] = choose(i, j);
    int n, k, c; cin >> n >> k >> c;
    vector<vi> a(n, vi(k));
    f0r(i, n) {
        f0r(j, k) {
            cin >> a[i][j];
        }
    }
    vector<vi> ord(k);
    f0r(i, n) {
        f0r(j, k) {
            ord[j].eb(a[i][j]);
        }
    }
    f0r(i, k) {
        sort(all(ord[i]));
        ord[i].erase(unique(all(ord[i])), ord[i].end());
    }
    auto get = [&](vi v) -> ll {
        ll res = 0;
        f0r(i, k) res += ord[i][v[i]];
        return res;
    };
    auto pq_comp = [&](vi a, vi b) {
        return get(a) < get(b);
    };
    set<vi> enc;
    priority_queue<vi, vector<vi>, decltype(pq_comp)> pq(pq_comp);
    auto ad_helper = [&](vi v) {
        if (enc.count(v)) return;
        f0r(i, k) {
            if (v[i] < 0) {
                return;
            }
        }
        enc.insert(v);
        pq.push(v);
    };
    auto ad = [&](vi v) {
        f0r(i, k) {
            vi tmp = v;
            tmp[i]--;
            ad_helper(tmp);
        }
    };
    vi tmp;
    f0r(i, k) tmp.eb(sz(ord[i])-1);
    ad_helper(tmp);
    ll num = 0;
    vl cnt((1 << k));
    vector<vl> dp(k+1, vl((1 << k)));
    // how many things you have
    // what is your current mask
    // vi done;
    while (num < c) {
        vi cur = pq.top();
        // cout << get(cur) << endl;
        // done.pb(get(cur));
        pq.pop();
        cnt.assign((1 << k), 0);
        dp.assign(k+1, vl(1 << k));
        f0r(i, n) {
            bool ok = true;
            int mask = 0;
            f0r(j, k) {
                int val = ord[j][cur[j]];
                if (a[i][j] > val) {
                    ok = false;
                    break;  
                }
                if (a[i][j] == val) {
                    mask |= (1 << j);
                }
            }
            if (!ok) continue;
            cnt[mask]++;
        }
        dp[0][0] = 1;
        f0r(mask, (1 << k)) {
            for (int bmask = (1 << k) - 1; bmask >= 0; bmask--) {
                for (int x = k; x >= 0; x--) {
                    if (dp[x][bmask] == 0) continue;
                    f1r(y, 1, cnt[mask]+1) {
                        if (x+y > k) continue;
                        dp[x+y][bmask|mask] += dp[x][bmask] * C[cnt[mask]][y];
                    }
                }
            }
        }
        // int acc = 0;
        // f0r(i, n) { 
        //     f0r(j, i) {
        //         if (max(a[i][0], a[j][0]) == ord[0][cur[0]] && max(a[i][1], a[j][1]) == ord[1][cur[1]]) {
        //             acc++;
        //         }
        //     }
        // }
        // assert(acc == dp[k].back());
        // cout << dp[k].back() << " " << cnt[0] << " " << cnt[1] << " " << cnt[2] << " " << cnt[3] << endl;
        // cout << dp[1][0] << endl;
        // assert(dp[k][(1 << k) - 1] == cnt[0] * cnt[3] + cnt[3] * (cnt[3] - 1) / 2 + cnt[1] * cnt[2] + (cnt[1] + cnt[2]) * cnt[3]);
        num += dp[k][(1 << k) - 1];
        // num += acc;
        if (num >= c) {
            // sort(all(done));
            // done.erase(unique(all(done)), done.end());
            // reverse(all(done));
            // for (int x : done) cout << x << endl;
            ll ans = get(cur);
            cout << ans << '\n';
            return 0;
        }
        ad(cur);
    }
    return 0;
}
# 결과 실행 시간 메모리 Grader output
1 Correct 5 ms 492 KB Output is correct
2 Correct 4 ms 492 KB Output is correct
3 Correct 1 ms 364 KB Output is correct
4 Correct 1 ms 364 KB Output is correct
# 결과 실행 시간 메모리 Grader output
1 Correct 20 ms 620 KB Output is correct
2 Correct 47 ms 1132 KB Output is correct
3 Correct 64 ms 1260 KB Output is correct
4 Correct 65 ms 1260 KB Output is correct
# 결과 실행 시간 메모리 Grader output
1 Correct 1 ms 364 KB Output is correct
2 Correct 1 ms 364 KB Output is correct
3 Correct 3 ms 492 KB Output is correct
4 Correct 3 ms 492 KB Output is correct
5 Correct 3 ms 492 KB Output is correct
6 Correct 10 ms 516 KB Output is correct
# 결과 실행 시간 메모리 Grader output
1 Correct 5 ms 492 KB Output is correct
2 Correct 4 ms 492 KB Output is correct
3 Correct 1 ms 364 KB Output is correct
4 Correct 1 ms 364 KB Output is correct
5 Correct 20 ms 620 KB Output is correct
6 Correct 47 ms 1132 KB Output is correct
7 Correct 64 ms 1260 KB Output is correct
8 Correct 65 ms 1260 KB Output is correct
9 Correct 1 ms 364 KB Output is correct
10 Correct 1 ms 364 KB Output is correct
11 Correct 3 ms 492 KB Output is correct
12 Correct 3 ms 492 KB Output is correct
13 Correct 3 ms 492 KB Output is correct
14 Correct 10 ms 516 KB Output is correct
15 Correct 103 ms 1260 KB Output is correct
16 Correct 29 ms 876 KB Output is correct
17 Correct 109 ms 2088 KB Output is correct
18 Correct 109 ms 1764 KB Output is correct
19 Correct 1 ms 512 KB Output is correct