# | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리 |
---|---|---|---|---|---|---|---|
429098 | timmyfeng | 저울 (IOI15_scales) | C++17 | 0 ms | 0 KiB |
이 제출은 이전 버전의 oj.uz에서 채점하였습니다. 현재는 제출 당시와는 다른 서버에서 채점을 하기 때문에, 다시 제출하면 결과가 달라질 수도 있습니다.
#include <bits/stdc++.h>
using namespace std;
#include "scales.h"
const int N = 6, M = 720, P[] = {1, 3, 9, 27, 81, 243, 729};
map<vector<int>, array<int, 4>> compare;
vector<int> perm[M], inv[M];
int ans[N];
void sort(vector<int> &p, int &i, int &j, int &k) {
if (p[i] > p[j]) {
swap(i, j);
}
if (p[j] > p[k]) {
swap(j, k);
}
if (p[i] > p[j]) {
swap(i, j);
}
}
int lightest(vector<int> &p, int i, int j, int k) {
sort(p, i, j, k);
return i;
}
int heaviest(vector<int> &p, int i, int j, int k) {
sort(p, i, j, k);
return k;
}
int median(vector<int> &p, int i, int j, int k) {
sort(p, i, j, k);
return j;
}
int next(vector<int> &p, int i, int j, int k, int l) {
sort(p, i, j, k);
if (p[l] < p[i]) {
return i;
} else if (p[l] < p[j]) {
return j;
} else if (p[l] < p[k]) {
return k;
} else {
return i;
}
}
bool solve(vector<int> mask, int remain) {
if (mask.size() <= 1) {
return true;
} else if ((int) mask.size() > P[remain]) {
return false;
} else if (compare.count(mask) > 0) {
return true;
}
vector<int> one, two, three;
for (int i = 0; i < N; ++i) {
for (int j = i + 1; j < N; ++j) {
for (int k = j + 1; k < N; ++k) {
one.clear(), two.clear(), three.clear();
for (auto u : mask) {
int x = lightest(inv[u], i, j, k);
(x == i ? one : (x == j ? two : three)).push_back(u);
}
if (solve(one, remain - 1) &&
solve(two, remain - 1) &&
solve(three, remain - 1)) {
compare[mask] = {i, j, k, -1};
return true;
}
one.clear(), two.clear(), three.clear();
for (auto u : mask) {
int x = heaviest(inv[u], i, j, k);
(x == i ? one : (x == j ? two : three)).push_back(u);
}
if (solve(one, remain - 1) &&
solve(two, remain - 1) &&
solve(three, remain - 1)) {
compare[mask] = {i, j, k, -2};
return true;
}
one.clear(), two.clear(), three.clear();
for (auto u : mask) {
int x = median(inv[u], i, j, k);
(x == i ? one : (x == j ? two : three)).push_back(u);
}
if (solve(one, remain - 1) &&
solve(two, remain - 1) &&
solve(three, remain - 1)) {
compare[mask] = {i, j, k, -3};
return true;
}
for (int l = 0; l < N; ++l) {
if (l != i && l != j && l != k) {
one.clear(), two.clear(), three.clear();
for (auto u : mask) {
int x = next(inv[u], i, j, k, l);
(x == i ? one : (x == j ? two : three)).push_back(u);
}
if (solve(one, remain - 1) &&
solve(two, remain - 1) &&
solve(three, remain - 1)) {
compare[mask] = {i, j, k, l};
return true;
}
}
}
}
}
}
return false;
}
void init(int t) {
vector<int> temp(N);
iota(temp.begin(), temp.end(), 0);
for (int i = 0; i < M; ++i) {
perm[i] = temp;
inv[i].resize(6);
for (int j = 0; j < N; ++j) {
inv[i][perm[i][j]] = j;
}
next_permutation(temp.begin(), temp.end());
}
vector<int> mask(M);
iota(mask.begin(), mask.end(), 0);
solve(mask, 6);
}