#include "scales.h"
#include <bits/stdc++.h>
using namespace std;
// Shorthands used throughout this file.
// NOTE(review): sz(v) does not parenthesize its argument; only pass plain names.
#define ar array
#define sz(v) int(v.size())
#define all(v) v.begin(), v.end()
typedef long long ll; // not used in the code visible here
// space[p]: the p-th permutation of {0,...,5} in lexicographic order (filled by init).
// The query helpers read space[p][r] as the 0-based coin at rank r, rank 0 = lightest.
ar<int, 6> space[720];
// inv[p]: index in `space` of the inverse of space[p]; entry c of that inverse is the
// rank of coin c under ordering space[p].
int inv[720];
// res[p][i][j][k][q]: the 0-based coin (one of i, j, k) that rec() expects the grader to
// answer if the true ordering is space[p]. q in 0..5 means getNextLightest(i, j, k, q),
// q == 6 getLightest, q == 7 getHeaviest, q == 8 getMedian. init only fills entries with
// i, j, k distinct and, for q < 6, q distinct from i, j, k.
int res[720][6][6][6][9];
int at(int a, int i) {
return space[a][i];
}
// Lightest of coins i, j, k under ordering space[a]. Each coin's rank is looked up
// through the inverse permutation, the smallest rank is kept, and it is mapped back
// to a coin.
int qry_min(int a, int i, int j, int k) {
    const int rank_perm = inv[a];
    int lowest = at(rank_perm, i);
    lowest = min(lowest, at(rank_perm, j));
    lowest = min(lowest, at(rank_perm, k));
    return at(a, lowest);
}
// Heaviest of coins i, j, k under ordering space[a]. Each coin's rank is looked up
// through the inverse permutation, the largest rank is kept, and it is mapped back
// to a coin.
int qry_max(int a, int i, int j, int k) {
    const int rank_perm = inv[a];
    int highest = at(rank_perm, i);
    highest = max(highest, at(rank_perm, j));
    highest = max(highest, at(rank_perm, k));
    return at(a, highest);
}
// Median of coins i, j, k under ordering space[a]. XOR-ing the three (distinct) coins
// with the lightest and the heaviest cancels those two and leaves the middle one.
int qry_med(int a, int i, int j, int k) {
    const int extremes = qry_min(a, i, j, k) ^ qry_max(a, i, j, k);
    return extremes ^ i ^ j ^ k;
}
int qry_nl(int a, int i, int j, int k, int l) {
int b = inv[a];
if (at(b, qry_max(a, i, j, k)) < at(b, l))
return qry_min(a, i, j, k);
pair<int, int> mn{6, -1};
if (at(b, i) > at(b, l))
mn = min(mn, make_pair(at(b, i), i));
if (at(b, j) > at(b, l))
mn = min(mn, make_pair(at(b, j), j));
if (at(b, k) > at(b, l))
mn = min(mn, make_pair(at(b, k), k));
return mn.second;
}
void init(int T) {
ar<int, 6> a{0, 1, 2, 3, 4, 5};
int c = 0;
do {
space[c++] = a;
} while (next_permutation(all(a)));
assert(c == 720);
for (int i = 0; i < 720; i++) {
ar<int, 6> b;
for (int j = 0; j < 6; j++)
b[space[i][j]] = j;
inv[i] = find(space, space+720, b) - space;
}
for (int a = 0; a < 720; a++) {
for (int i = 0; i < 6; i++) for (int j = 0; j < 6; j++) if (i != j) for (int k = 0; k < 6; k++) if (i != k && j != k) {
res[a][i][j][k][6] = qry_min(a, i, j, k);
res[a][i][j][k][7] = qry_max(a, i, j, k);
res[a][i][j][k][8] = qry_med(a, i, j, k);
for (int l = 0; l < 6; l++) if (i != l && j != l && k != l)
res[a][i][j][k][l] = qry_nl(a, i, j, k, l);
}
}
}
int rec(vector<int> me) {
if (sz(me) == 1)
return me[0];
pair<double, ar<int, 4>> mn{-1, {}};
for (int i = 0; i < 6; i++) for (int j = 0; j < 6; j++) if (i != j) for (int k = 0; k < 6; k++) if (i != k && j != k) {
for (int l = 0; l < 9; l++) if (i != l && j != l && k != l) {
ar<int, 3> div{0, 0, 0};
for (int a : me) {
int p = res[a][i][j][k][l];
if (p == i)
div[0]++;
else if (p == j)
div[1]++;
else if (p == k)
div[2]++;
else
assert(false);
}
double e = 0;
for (int x : div) if (x) {
double p = double(x) / sz(me);
e += -p * log2(p);
}
mn = max(mn, make_pair(e, ar<int, 4>{i, j, k, l}));
}
}
auto [i, j, k, l] = mn.second;
int p = [&]() {
if (l < 6)
return getNextLightest(i+1, j+1, k+1, l+1);
else if (l == 6)
return getLightest(i+1, j+1, k+1);
else if (l == 7)
return getHeaviest(i+1, j+1, k+1);
else if (l == 8)
return getMedian(i+1, j+1, k+1);
else
assert(false);
}() - 1;
vector<int> nxt;
for (int a : me) if (res[a][i][j][k][l] == p)
nxt.push_back(a);
return rec(nxt);
}
void orderCoins() {
vector<int> me(720);
iota(all(me), 0);
int i = rec(me);
int ans[6];
for (int j = 0; j < 6; j++)
ans[j] = space[i][j]+1;
answer(ans);
}
// Judge results table pasted from the submission page (not part of the program):
// # | Verdict | Execution time | Memory | Grader output |
// ---|
// Fetching results... |