#include "mushrooms.h"
#include <bits/stdc++.h>
using namespace std;
// Move codes exchanged between the planner (best_move) and the executor (do_move).
// GIVEUP (0) is also the "not computed yet" marker in the planner's memo table.
constexpr int GIVEUP = 0;
constexpr int COUNT = 1;
constexpr int SWAP = 2;
constexpr int MOVE = 3;  // MOVE + k runs the query template moves[k]
constexpr int INF = 1 << 30;

// A group: a small list of mushroom indices stored as one unit in a bucket.
using group = vector<int>;
// A distribution: a string of group labels plus a permutation over the gathered indices.
using distrib = pair<string, vector<int>>;
// A movement: the required input distribution, and for each machine answer the resulting one.
using movement = pair<distrib, map<int, distrib>>;
template<class T>
void sort_uniq(vector<T>& v) {
  // Sort, then drop adjacent duplicates so each distinct value remains exactly once.
  sort(begin(v), end(v));
  const auto last_distinct = unique(begin(v), end(v));
  v.erase(last_distinct, end(v));
}
// Table of multi-mushroom query templates ("movements"); move code MOVE + k runs moves[k].
// Entry layout: {{req_groups, req_perm}, {answer -> {res_groups, res_perm}}}.
//  * req_groups: one label (a character of group_labels) per group that do_move() pops from
//    the back of that label's bucket; the popped groups' indices are concatenated, in label
//    order, into a list `indices`.
//  * req_perm: the machine is queried with indices[req_perm[0]], indices[req_perm[1]], ...
//    It may use fewer positions than `indices` holds (entry [4] leaves one index out).
//  * answer: the value returned by that query; do_move() looks up its result here.
//  * res_groups / res_perm: indices[res_perm[0]], indices[res_perm[1]], ... is cut into
//    consecutive groups whose sizes are group_sizes of the labels in res_groups, and each
//    group is pushed onto its label's bucket. res_perm must have as many entries as
//    `indices` (asserted in do_move()).
// Answers missing from a map are presumably impossible for that layout — NOTE(review): confirm.
// Entry order matters: best() selects entries 3..7 by hard-coded index.
// The short comment above each entry sketches its query layout.
vector<movement> moves = {
// [0] A....
{{"AWWWW",{0,1,2,3,4}}, {
{0, {"AAAAA",{0,1,2,3,4}}},
{1, {"A1B", {0,1,2,3,4}}},
{2, {"A2A", {0,1,2,3,4}}},
{3, {"A3B", {0,1,2,3,4}}},
{4, {"ABABA",{0,1,2,3,4}}},
}},
// [1] .A...A
{{"WAWWWA",{0,1,2,3,4,5}}, {
{0, {"AAAAAA",{0,1,2,3,4,5}}},
{1, {"BAAAAA",{0,1,2,3,4,5}}},
{2, {"AA2A", {0,1,2,3,4,5}}},
{3, {"BA2A", {0,1,2,3,4,5}}},
{4, {"AABABA",{0,1,2,3,4,5}}},
{5, {"BABABA",{0,1,2,3,4,5}}},
}},
// [2] .A.A.A
{{"WAWAWA",{0,1,2,3,4,5}}, {
{0, {"AAAAAA",{0,1,2,3,4,5}}},
{1, {"BAAAAA",{0,1,2,3,4,5}}},
{2, {"AA0AA", {0,1,2,4,3,5}}},
{3, {"BA0AA", {0,1,2,4,3,5}}},
{4, {"AABABA",{0,1,2,3,4,5}}},
{5, {"BABABA",{0,1,2,3,4,5}}},
}},
// [3] .A...B  (chosen by best() when a, b >= 4 and no g0..g3 group is pending)
{{"WAWWWB",{0,1,2,3,4,5}}, {
{1, {"AA1B",{0,1,2,3,4,5}}},
{2, {"BA1B",{0,1,2,3,4,5}}},
{3, {"AA3B",{0,1,2,3,4,5}}},
{4, {"BA3B",{0,1,2,3,4,5}}},
}},
// [4] 0A...B  (consumes a pending g0 group)
{{"0AWWWB",{0,2,3,4,5,6}}, {
{1, {"AA1BB",{0,2,3,4,5,6,1}}},
{2, {"BA1BA",{0,2,3,4,5,6,1}}},
{3, {"AA3BB",{0,2,3,4,5,6,1}}},
{4, {"BA3BA",{0,2,3,4,5,6,1}}},
}},
// [5] .A0A1A2A  (consumes a pending g1 group)
{{"WA1AAA",{0,1,2,5,3,6,4,7}}, {
{0, {"AAAAAAAA",{0,1,2,5,3,6,4,7}}},
{1, {"BAAAAAAA",{0,1,2,5,3,6,4,7}}},
{2, {"AAAAAABA",{0,1,2,5,3,6,4,7}}},
{3, {"BAAAAABA",{0,1,2,5,3,6,4,7}}},
{4, {"AAAABABA",{0,1,2,5,3,6,4,7}}},
{5, {"BAAABABA",{0,1,2,5,3,6,4,7}}},
{6, {"AABABABA",{0,1,2,5,3,6,4,7}}},
{7, {"BABABABA",{0,1,2,5,3,6,4,7}}},
}},
// [6] 0A1A  (consumes a pending g2 group)
{{"2AA",{0,3,1,4}}, {
{0, {"AAAAB",{0,3,1,4,2}}},
{1, {"BAAAA",{0,3,1,4,2}}},
{2, {"AABAW",{0,3,1,4,2}}},
{3, {"BABAW",{0,3,1,4,2}}},
}},
// [7] .A0A2AB1B  (consumes a pending g3 group)
{{"WA3AABB",{0,1,2,5,4,6,7,3,8}}, {
{1, {"AAAAAABBB",{0,1,2,5,4,6,7,3,8}}},
{2, {"BAAAAABBB",{0,1,2,5,4,6,7,3,8}}},
{3, {"AABAAABBB",{0,1,2,5,4,6,7,3,8}}},
{4, {"BABAAABBB",{0,1,2,5,4,6,7,3,8}}},
{5, {"AABAAABAB",{0,1,2,5,4,6,7,3,8}}},
{6, {"BABAAABAB",{0,1,2,5,4,6,7,3,8}}},
{7, {"AABABABAB",{0,1,2,5,4,6,7,3,8}}},
{8, {"BABABABAB",{0,1,2,5,4,6,7,3,8}}},
}},
};
// Labels of the seven bucket kinds; the label at position i names bucket groups[i].
const string group_labels = "AB0123W";
const int groupc = static_cast<int>(group_labels.size());
// Number of indices held by one group of each kind (same order as group_labels).
const vector<int> group_sizes = {1, 1, 2, 3, 3, 3, 1};
// Label character -> bucket position; filled in by init_groups().
int group_index[256];
// One bucket (a list of groups) per label.
vector<vector<group>> groups(7);
// Named handles for the buckets, in group_labels order.
vector<group>& a = groups[0];
vector<group>& b = groups[1];
vector<group>& g0 = groups[2];
vector<group>& g1 = groups[3];
vector<group>& g2 = groups[4];
vector<group>& g3 = groups[5];
vector<group>& w = groups[6];
// Characters currently attached to buckets a and b (exchanged by SWAP).
char A;
char B;
// Classified counts for buckets a / b, and number of machine queries issued so far.
int cA;
int cB;
int q;
// Reset all per-run state: labels, counters, query count and buckets, and rebuild the
// label -> bucket lookup used by do_move() and init_deltas().
void init_groups() {
  A = 'A';
  B = 'B';
  cA = cB = q = 0;
  for (vector<group>& g : groups) g.clear();
  // Index through unsigned char: plain char may be signed, and a negative subscript
  // would write outside group_index (-Wchar-subscripts).
  for (int i = 0; i < groupc; i++) {
    group_index[static_cast<unsigned char>(group_labels[i])] = i;
  }
}
// Ask the machine about `que`, counting the call in the global query counter q.
// The counter is bumped before the machine is consulted.
int query(const vector<int>& que) {
  assert(!que.empty());
  ++q;
  return use_machine(que);
}
// Apply one planner move to the global buckets. COUNT and MOVE + k issue exactly one
// machine query; SWAP issues none.
//   COUNT    : query W A W A ... pairing unknown W singletons with known A singletons.
//   SWAP     : exchange the roles of the A and B buckets and their counters.
//   MOVE + k : run query template moves[k] and redistribute its indices by the answer.
void do_move(int move) {
  vector<int> que;
  if (move == COUNT) {
    // Interleave one unknown (taken from the back of w) with each known A group.
    const int take = static_cast<int>(min(w.size(), a.size()));
    for (int i = 0; i < take; i++) {
      que.push_back(w.back().back());
      w.pop_back();
      que.insert(que.end(), a[i].begin(), a[i].end());
    }
    // Assuming use_machine() counts adjacent positions of different type: the query ends
    // in an A, so the answer's parity says whether the first unknown differs from A, and
    // each other unknown contributes 0 or 2.
    int res = query(que);
    groups[res & 1].emplace_back(1, que.front());
    res >>= 1;
    assert(0 <= res && res < take);
    cA += take - 1 - res;
    cB += res;
  } else if (move == SWAP) {
    assert(g2.empty());
    swap(A, B);
    swap(cA, cB);
    swap(a, b);
    // NOTE(review): g1/g3 contents are reversed, presumably because their element order is
    // relative to the A/B labelling — confirm against the moves table.
    for (group& vec : g1) reverse(vec.begin(), vec.end());
    for (group& vec : g3) reverse(vec.begin(), vec.end());
  } else {
    move -= MOVE;
    assert(0 <= move && move < static_cast<int>(moves.size()));
    const auto& [req, results] = moves[move];
    const auto& [req_groups, req_perm] = req;
    // Gather the indices of one group per required label, in label order.
    vector<int> indices;
    for (char g : req_groups) {
      vector<group>& src = groups[group_index[static_cast<unsigned char>(g)]];
      assert(!src.empty());
      indices.insert(indices.end(), src.back().begin(), src.back().end());
      src.pop_back();
    }
    for (int i : req_perm) que.push_back(indices[i]);
    // .at() instead of operator[]: an answer the table does not list must not silently
    // insert an empty result. With assertions off, that would drop the gathered indices.
    const auto& [res_groups, res_perm] = results.at(query(que));
    assert(indices.size() == res_perm.size());
    // Reorder the gathered indices, then cut them into the resulting groups.
    vector<int> target;
    target.reserve(res_perm.size());
    for (int i : res_perm) target.push_back(indices[i]);
    auto target_it = target.begin();
    for (char g : res_groups) {
      const int gi = group_index[static_cast<unsigned char>(g)];
      assert(target.end() - target_it >= group_sizes[gi]);
      const auto ntarget_it = target_it + group_sizes[gi];
      groups[gi].emplace_back(target_it, ntarget_it);
      target_it = ntarget_it;
    }
    assert(target_it == target.end());
  }
}
using deltatype = vector<int>;
// One entry per moves[k] (same index): the per-label group counts the move consumes, and the
// sorted distinct net changes (consumed minus produced, per label) over all listed answers.
vector<pair<deltatype,vector<deltatype>>> all_deltas;
void init_deltas() {
  // Rebuild from scratch. count_mushrooms() calls this on every run, and appending again
  // would keep growing the table with duplicate entries.
  all_deltas.clear();
  all_deltas.reserve(moves.size());
  for (const auto& [req, results] : moves) {
    const string& req_groups = req.first;
    deltatype reqd(groupc, 0);
    for (char label : req_groups) reqd[group_index[static_cast<unsigned char>(label)]]++;
    vector<deltatype> deltas;
    deltas.reserve(results.size());
    for (const auto& entry : results) {
      const string& res_groups = entry.second.first;
      deltatype resd = reqd;
      for (char label : res_groups) resd[group_index[static_cast<unsigned char>(label)]]--;
      deltas.push_back(std::move(resd));
    }
    sort_uniq(deltas);
    all_deltas.emplace_back(std::move(reqd), std::move(deltas));
  }
}
// Query horizon of the planner: best() only considers states with q <= Q
// (presumably the machine-query budget — NOTE(review): confirm against the task limits).
constexpr int Q = 226;
// Number of (g0, g1, g2, g3) shapes enumerated by init_ts(): 3 choices for g0 + g2 <= 1
// times 6 for g1 + g3 <= 2.
constexpr int T = 18;
// Memo table for best(), indexed by get_hsh(); an entry whose move is 0 is not yet computed.
vector<pair<int, char>> _best(static_cast<size_t>((Q + 8) * (Q + 8) * (Q + 8) / 3 * T));
// Shape -> dense id in [0, T), filled by init_ts(); only indices 0..3 are used.
int ts[5][5][5][5] = {};
void init_ts() {
int tc = 0;
for (int g0 = 0; g0 <= 3; g0++)
for (int g1 = 0; g1 <= 3; g1++)
for (int g2 = 0; g2 <= 3; g2++)
for (int g3 = 0; g3 <= 3; g3++)
if (g0 + g2 <= 1 && g1 + g3 <= 2) ts[g0][g1][g2][g3] = tc++;
}
// Score best() assigns to finishing with COUNT queries when q queries remain, given a and b
// singletons in the two known buckets (the argument order of a and b does not matter).
int count_now(int q, int a, int b) {
  assert(q >= 0);
  const int hi = max(a, b);
  const int lo = min(a, b);
  const int spare = max(0, q + lo - hi);
  // spare >= 0, so these are floor(spare / 2) and ceil(spare / 2).
  const int half_down = spare / 2;
  const int half_up = (spare + 1) / 2;
  return hi * (q + 1) + lo + half_down * half_up;
}
// Map a planner state (query count q, bucket sizes d) to its slot in _best.
// (A, B) counts are ordered with a flip bit, then (remaining, hi, lo) is numbered
// tetrahedrally/triangularly and combined with the shape id from ts.
int get_hsh(int q, const deltatype& d) {
  int hi = d[0];
  int lo = d[1];
  int flipped = 0;
  if (hi < lo) {
    swap(hi, lo);
    flipped = 1;
  }
  // Remaining queries plus the same +5 slack that best() allows for the bucket sizes.
  const int left = Q - q + 5;
  assert(left >= hi && hi >= lo);
  int key = left * (left + 1) * (left + 2) / 6 + hi * (hi + 1) / 2 + lo;
  key = 2 * key + flipped;
  key = key * T + ts[d[2]][d[3]][d[4]][d[5]];
  assert(0 <= key && static_cast<size_t>(key) < _best.size());
  return key;
}
int best_solve(int q, const deltatype& d);
pair<int,int> best(int q, const deltatype& d) {
int a = d[0], b = d[1], g0 = d[2], g1 = d[3], g2 = d[4], g3 = d[5];
if (!(0 <= a && a <= Q - q + 5 && 0 <= b && b <= Q - q + 5 &&
0 <= g0 && 0 <= g1 && 0 <= g2 && 0 <= g3 &&
g0 + g2 <= 1 && g1 + g3 <= 2 &&
0 <= q && q <= Q)) return {-INF, GIVEUP};
if (a < b && g2 == 0) return {best_solve(q, {b, a, g0, g1, g2, g3}), SWAP};
if (Q < q) return {count_now(0, a, b), COUNT};
int hsh = get_hsh(q, d);
if (!_best[hsh].second) {
int solve = count_now(Q - q, a, b);
char move = COUNT;
auto try_cand = [&](int cmove) {
auto& [req, deltas] = all_deltas[cmove];
for (int i = 0; i < d.size(); i++) if (d[i] < req[i]) return;
int csolve = INF;
for (auto& delta : deltas) {
deltatype nd = d;
for (int i = 0; i < d.size(); i++) nd[i] -= delta[i];
csolve = min(csolve, best_solve(q + 1, nd));
}
if (solve < csolve) {
solve = csolve;
move = MOVE + cmove;
}
};
if (a >= 4 && b >= 4) {
try_cand(g3 ? 7 : g2 ? 6 : g1 ? 5 : g0 ? 4 : 3);
} else {
for (int i = 0; i < moves.size(); i++) try_cand(i);
}
_best[hsh] = {solve, move};
}
return _best[hsh];
}
int best_solve(int q, const deltatype& d) { return best(q, d).first; }
int best_move(int q, const deltatype& d) { return best(q, d).second; }
int count_mushrooms(int n) {
init_groups();
init_deltas();
init_ts();
for (int i = 0; i < n; i++) {
(i ? w : a).emplace_back(1, i);
}
while (w.size() >= 10) {
int move = best_move(q, {int(a.size()), int(b.size()), int(g0.size()), int(g1.size()), int(g2.size()), int(g3.size())});
if (move == COUNT) break;
do_move(move);
}
for (; !g0.empty(); g0.pop_back()) for (group& g = g0.back(); !g.empty(); g.pop_back()) w.emplace_back(1, g.back());
for (; !g1.empty(); g1.pop_back()) for (group& g = g1.back(); !g.empty(); g.pop_back()) w.emplace_back(1, g.back());
for (; !g2.empty(); g2.pop_back()) for (group& g = g2.back(); !g.empty(); g.pop_back()) w.emplace_back(1, g.back());
for (; !g3.empty(); g3.pop_back()) for (group& g = g3.back(); !g.empty(); g.pop_back()) w.emplace_back(1, g.back());
while (!w.empty()) do_move(a.size() < b.size() ? SWAP : COUNT);
while (!a.empty()) a.pop_back(), cA++;
while (!b.empty()) b.pop_back(), cB++;
for (vector<group>& g : groups) assert(g.empty());
assert(cA + cB == n);
return (A == 'A') * cA + (B == 'A') * cB;
}
Compilation message
mushrooms.cpp: In function 'void init_groups()':
mushrooms.cpp:117:65: warning: array subscript has type 'char' [-Wchar-subscripts]
117 | for (int i = 0; i < groupc; i++) group_index[group_labels[i]] = i;
| ^
In file included from /usr/include/c++/9/cassert:44,
from /usr/include/x86_64-linux-gnu/c++/9/bits/stdc++.h:33,
from mushrooms.cpp:2:
mushrooms.cpp: In function 'void do_move(int)':
mushrooms.cpp:153:34: warning: comparison of integer expressions of different signedness: 'int' and 'std::vector<std::pair<std::pair<std::__cxx11::basic_string<char>, std::vector<int> >, std::map<int, std::pair<std::__cxx11::basic_string<char>, std::vector<int> > > > >::size_type' {aka 'long unsigned int'} [-Wsign-compare]
153 | assert(0 <= move && move < moves.size());
| ~~~~~^~~~~~~~~~~~~~
mushrooms.cpp:159:53: warning: array subscript has type 'char' [-Wchar-subscripts]
159 | vector<group>& src = groups[group_index[g]];
| ^
mushrooms.cpp:172:67: warning: array subscript has type 'char' [-Wchar-subscripts]
172 | auto ntarget_it = target_it + group_sizes[group_index[g]];
| ^
mushrooms.cpp:173:32: warning: array subscript has type 'char' [-Wchar-subscripts]
173 | groups[group_index[g]].emplace_back(target_it, ntarget_it);
| ^
mushrooms.cpp: In function 'void init_deltas()':
mushrooms.cpp:189:52: warning: array subscript has type 'char' [-Wchar-subscripts]
189 | for (char g : req_groups) reqd[group_index[g]]++;
| ^
mushrooms.cpp:193:56: warning: array subscript has type 'char' [-Wchar-subscripts]
193 | for (char g : res_groups) resd[group_index[g]]--;
| ^
mushrooms.cpp: In function 'int count_now(int, int, int)':
mushrooms.cpp:219:44: warning: suggest parentheses around '+' inside '>>' [-Wparentheses]
219 | return a * (q + 1) + b + (d >> 1) * (d + 1 >> 1);
| ~~^~~
In file included from /usr/include/c++/9/cassert:44,
from /usr/include/x86_64-linux-gnu/c++/9/bits/stdc++.h:33,
from mushrooms.cpp:2:
mushrooms.cpp: In function 'int get_hsh(int, const deltatype&)':
mushrooms.cpp:230:28: warning: comparison of integer expressions of different signedness: 'int' and 'std::vector<std::pair<int, char> >::size_type' {aka 'long unsigned int'} [-Wsign-compare]
230 | assert(0 <= hsh && hsh < _best.size());
| ~~~~^~~~~~~~~~~~~~
mushrooms.cpp: In lambda function:
mushrooms.cpp:249:31: warning: comparison of integer expressions of different signedness: 'int' and 'std::vector<int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
249 | for (int i = 0; i < d.size(); i++) if (d[i] < req[i]) return;
| ~~^~~~~~~~~~
mushrooms.cpp:253:35: warning: comparison of integer expressions of different signedness: 'int' and 'std::vector<int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
253 | for (int i = 0; i < d.size(); i++) nd[i] -= delta[i];
| ~~^~~~~~~~~~
mushrooms.cpp: In function 'std::pair<int, int> best(int, const deltatype&)':
mushrooms.cpp:265:31: warning: comparison of integer expressions of different signedness: 'int' and 'std::vector<std::pair<std::pair<std::__cxx11::basic_string<char>, std::vector<int> >, std::map<int, std::pair<std::__cxx11::basic_string<char>, std::vector<int> > > > >::size_type' {aka 'long unsigned int'} [-Wsign-compare]
265 | for (int i = 0; i < moves.size(); i++) try_cand(i);
| ~~^~~~~~~~~~~~~~
# |
Verdict |
Execution time |
Memory |
Grader output |
1 |
Correct |
355 ms |
602104 KB |
Output is correct |
2 |
Correct |
355 ms |
602232 KB |
Output is correct |
3 |
Correct |
354 ms |
602232 KB |
Output is correct |
4 |
Correct |
902 ms |
602488 KB |
Output is correct |
5 |
Correct |
901 ms |
602252 KB |
Output is correct |
6 |
Correct |
935 ms |
602316 KB |
Output is correct |
7 |
Correct |
908 ms |
603444 KB |
Output is correct |
8 |
Correct |
928 ms |
603572 KB |
Output is correct |
9 |
Correct |
922 ms |
603444 KB |
Output is correct |
10 |
Correct |
942 ms |
603444 KB |
Output is correct |
11 |
Correct |
927 ms |
603576 KB |
Output is correct |
12 |
Correct |
941 ms |
603388 KB |
Output is correct |
13 |
Correct |
916 ms |
603488 KB |
Output is correct |
14 |
Correct |
906 ms |
602872 KB |
Output is correct |
15 |
Correct |
910 ms |
603444 KB |
Output is correct |
16 |
Correct |
906 ms |
603444 KB |
Output is correct |
17 |
Correct |
913 ms |
602872 KB |
Output is correct |
18 |
Correct |
932 ms |
603444 KB |
Output is correct |
19 |
Correct |
917 ms |
603444 KB |
Output is correct |
20 |
Correct |
920 ms |
603444 KB |
Output is correct |
21 |
Correct |
914 ms |
603444 KB |
Output is correct |
22 |
Correct |
920 ms |
603444 KB |
Output is correct |
23 |
Correct |
923 ms |
603396 KB |
Output is correct |
24 |
Correct |
919 ms |
602804 KB |
Output is correct |
25 |
Correct |
925 ms |
603444 KB |
Output is correct |
26 |
Correct |
937 ms |
603448 KB |
Output is correct |
27 |
Correct |
950 ms |
603580 KB |
Output is correct |
28 |
Correct |
919 ms |
603572 KB |
Output is correct |
29 |
Correct |
947 ms |
603532 KB |
Output is correct |
30 |
Correct |
918 ms |
603380 KB |
Output is correct |
31 |
Correct |
909 ms |
603444 KB |
Output is correct |
32 |
Correct |
915 ms |
603444 KB |
Output is correct |
33 |
Correct |
924 ms |
603580 KB |
Output is correct |
34 |
Correct |
922 ms |
603444 KB |
Output is correct |
35 |
Correct |
947 ms |
603388 KB |
Output is correct |
36 |
Correct |
921 ms |
603444 KB |
Output is correct |
37 |
Correct |
915 ms |
603452 KB |
Output is correct |
38 |
Correct |
912 ms |
603444 KB |
Output is correct |
39 |
Correct |
921 ms |
603444 KB |
Output is correct |
40 |
Correct |
927 ms |
603444 KB |
Output is correct |
41 |
Correct |
924 ms |
603444 KB |
Output is correct |
42 |
Correct |
914 ms |
603444 KB |
Output is correct |
43 |
Correct |
929 ms |
603444 KB |
Output is correct |
44 |
Correct |
935 ms |
603444 KB |
Output is correct |
45 |
Correct |
950 ms |
603444 KB |
Output is correct |
46 |
Correct |
940 ms |
603444 KB |
Output is correct |
47 |
Correct |
926 ms |
603444 KB |
Output is correct |
48 |
Correct |
931 ms |
603424 KB |
Output is correct |
49 |
Correct |
932 ms |
603444 KB |
Output is correct |
50 |
Correct |
914 ms |
603484 KB |
Output is correct |
51 |
Correct |
968 ms |
603572 KB |
Output is correct |
52 |
Correct |
924 ms |
603444 KB |
Output is correct |
53 |
Correct |
935 ms |
603444 KB |
Output is correct |
54 |
Correct |
960 ms |
603444 KB |
Output is correct |
55 |
Correct |
957 ms |
603444 KB |
Output is correct |
56 |
Correct |
955 ms |
603444 KB |
Output is correct |
57 |
Correct |
949 ms |
603444 KB |
Output is correct |
58 |
Correct |
947 ms |
603700 KB |
Output is correct |
59 |
Correct |
910 ms |
603444 KB |
Output is correct |
60 |
Correct |
907 ms |
603444 KB |
Output is correct |
61 |
Correct |
924 ms |
603448 KB |
Output is correct |
62 |
Correct |
352 ms |
602104 KB |
Output is correct |
63 |
Correct |
359 ms |
602232 KB |
Output is correct |
64 |
Correct |
354 ms |
602232 KB |
Output is correct |
65 |
Correct |
359 ms |
602232 KB |
Output is correct |
66 |
Correct |
355 ms |
602168 KB |
Output is correct |
67 |
Correct |
364 ms |
602360 KB |
Output is correct |
68 |
Correct |
351 ms |
602104 KB |
Output is correct |
69 |
Correct |
360 ms |
602104 KB |
Output is correct |
70 |
Correct |
352 ms |
602104 KB |
Output is correct |
71 |
Correct |
355 ms |
602232 KB |
Output is correct |
72 |
Correct |
353 ms |
602104 KB |
Output is correct |
73 |
Correct |
358 ms |
602232 KB |
Output is correct |
74 |
Correct |
354 ms |
602104 KB |
Output is correct |
75 |
Correct |
362 ms |
602232 KB |
Output is correct |
76 |
Correct |
357 ms |
602128 KB |
Output is correct |
77 |
Correct |
358 ms |
602104 KB |
Output is correct |
78 |
Correct |
351 ms |
602104 KB |
Output is correct |
79 |
Correct |
360 ms |
602232 KB |
Output is correct |
80 |
Correct |
354 ms |
602104 KB |
Output is correct |
81 |
Correct |
354 ms |
602360 KB |
Output is correct |
82 |
Correct |
352 ms |
602104 KB |
Output is correct |
83 |
Correct |
355 ms |
602232 KB |
Output is correct |
84 |
Correct |
357 ms |
602104 KB |
Output is correct |
85 |
Correct |
359 ms |
602360 KB |
Output is correct |
86 |
Correct |
352 ms |
602104 KB |
Output is correct |
87 |
Correct |
361 ms |
602232 KB |
Output is correct |
88 |
Correct |
359 ms |
602104 KB |
Output is correct |
89 |
Correct |
357 ms |
602232 KB |
Output is correct |
90 |
Correct |
359 ms |
602104 KB |
Output is correct |
91 |
Correct |
367 ms |
602104 KB |
Output is correct |