Submission #1212355

| # | Time | Username | Problem | Language | Result | Execution time | Memory |
|---|---|---|---|---|---|---|---|
| 1212355 | | box | Scales (IOI15_scales) | C++20 | 76.49 / 100 | 89 ms | 5916 KiB |
#include "scales.h"
#include <bits/stdc++.h>
using namespace std;

// Short aliases used throughout this file.
#define ar array
#define sz(v) int(v.size())
#define all(v) v.begin(), v.end()
typedef long long ll;

// space[c]: the c-th permutation of coin indices {0..5}, in the lexicographic
// order produced by next_permutation in init(). orderCoins() reports
// space[c][0..5] (+1) to answer(), so space[c][p] is the coin placed at
// position p of that candidate ordering.
ar<int, 6> space[720];
// inv[c]: index into `space` of the inverse of permutation space[c] (filled in
// init()); space[inv[c]][coin] is therefore the position of `coin` in space[c].
int inv[720];
// res[c][i][j][k][q]: precomputed 0-based coin answered by query q on distinct
// coins i, j, k when the true ordering is space[c]:
//   q < 6  -> next-lightest query relative to coin q (see qry_nl; q not in {i, j, k}),
//   q == 6 -> lightest, q == 7 -> heaviest, q == 8 -> median.
// Entries for non-distinct index combinations are never written and stay 0.
int res[720][6][6][6][9];

int at(int a, int i) {
    return space[a][i];
}
// Lightest of coins i, j, k under ordering space[a]: the one whose position in
// that ordering is smallest (this is what getLightest is matched against).
int qry_min(int a, int i, int j, int k) {
    const int pos_of = inv[a];  // space[pos_of][coin] == position of coin in space[a]
    int lo = at(pos_of, i);
    if (at(pos_of, j) < lo)
        lo = at(pos_of, j);
    if (at(pos_of, k) < lo)
        lo = at(pos_of, k);
    return at(a, lo);
}
// Heaviest of coins i, j, k under ordering space[a]: the one whose position in
// that ordering is largest (this is what getHeaviest is matched against).
int qry_max(int a, int i, int j, int k) {
    const int pos_of = inv[a];  // space[pos_of][coin] == position of coin in space[a]
    int hi = at(pos_of, i);
    if (at(pos_of, j) > hi)
        hi = at(pos_of, j);
    if (at(pos_of, k) > hi)
        hi = at(pos_of, k);
    return at(a, hi);
}
// Median of coins i, j, k under ordering space[a].
int qry_med(int a, int i, int j, int k) {
    const int lightest = qry_min(a, i, j, k);
    const int heaviest = qry_max(a, i, j, k);
    // XOR-ing all three labels with the two extremes cancels them out,
    // leaving the remaining (middle) coin.
    return (i ^ j ^ k) ^ (lightest ^ heaviest);
}
// Next-lightest query under ordering space[a]: among coins i, j, k, the one
// with the smallest position greater than coin l's; if all three sit before
// l, the lightest of them is returned instead.
int qry_nl(int a, int i, int j, int k, int l) {
    const int pos_of = inv[a];
    const int ref = at(pos_of, l);

    if (at(pos_of, qry_max(a, i, j, k)) < ref)
        return qry_min(a, i, j, k);

    pair<int, int> best{6, -1};  // (position, coin); 6 is past every position
    for (int coin : {i, j, k})
        if (at(pos_of, coin) > ref)
            best = min(best, make_pair(at(pos_of, coin), coin));
    return best.second;
}

void init(int T) {
    ar<int, 6> a{0, 1, 2, 3, 4, 5};
    int c = 0;

    do {
        space[c++] = a;
    } while (next_permutation(all(a)));
    assert(c == 720);

    for (int i = 0; i < 720; i++) {
        ar<int, 6> b;
        for (int j = 0; j < 6; j++)
            b[space[i][j]] = j;
        inv[i] = find(space, space+720, b) - space;
    }

    for (int a = 0; a < 720; a++) {
        for (int i = 0; i < 6; i++) for (int j = 0; j < 6; j++) if (i != j) for (int k = 0; k < 6; k++) if (i != k && j != k) {
            res[a][i][j][k][6] = qry_min(a, i, j, k);
            res[a][i][j][k][7] = qry_max(a, i, j, k);
            res[a][i][j][k][8] = qry_med(a, i, j, k);

            for (int l = 0; l < 6; l++) if (i != l && j != l && k != l)
                res[a][i][j][k][l] = qry_nl(a, i, j, k, l);
        }
    }
}

int rec(vector<int> me) {
    if (sz(me) == 1)
        return me[0];

    pair<double, ar<int, 4>> mn{-1, {}};

    for (int i = 0; i < 6; i++) for (int j = 0; j < 6; j++) if (i != j) for (int k = 0; k < 6; k++) if (i != k && j != k) {
        for (int l = 0; l < 9; l++) if (i != l && j != l && k != l) {
            ar<int, 3> div{0, 0, 0};
            for (int a : me) {
                int p = res[a][i][j][k][l];
                if (p == i)
                    div[0]++;
                else if (p == j)
                    div[1]++;
                else if (p == k)
                    div[2]++;
                else
                    assert(false);
            }
            double e = 0;
            for (int x : div) if (x) {
                double p = double(x) / sz(me);
                e += -p * log2(p);
            }
            mn = max(mn, make_pair(e, ar<int, 4>{i, j, k, l}));
        }
    }

    auto [i, j, k, l] = mn.second;
    int p = [&]() {
        if (l < 6)
            return getNextLightest(i+1, j+1, k+1, l+1);
        else if (l == 6)
            return getLightest(i+1, j+1, k+1);
        else if (l == 7)
            return getHeaviest(i+1, j+1, k+1);
        else if (l == 8)
            return getMedian(i+1, j+1, k+1);
        else
            assert(false);
    }() - 1;

    vector<int> nxt;
    for (int a : me) if (res[a][i][j][k][l] == p)
        nxt.push_back(a);

    return rec(nxt);
}

void orderCoins() {
    vector<int> me(720);
    iota(all(me), 0);

    int i = rec(me);
    
    int ans[6];
    for (int j = 0; j < 6; j++)
        ans[j] = space[i][j]+1;
    answer(ans);
}
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
Fetching results...