Submission #304368

# | Time | Username | Problem | Language | Result | Execution time | Memory
304368 | — | kevinsogo | Counting Mushrooms (IOI20_mushrooms) | C++17 | 100 / 100 | 968 ms | 603700 KiB
#include "mushrooms.h"
#include <bits/stdc++.h>
using namespace std;

// Move codes returned by the planner (`best`) and consumed by `do_move`.
// Codes >= MOVE select entry `code - MOVE` of the `moves` table.
const int GIVEUP = 0;
const int COUNT = 1;
const int SWAP = 2;
const int MOVE = 3;
const int INF = 1 << 30;

// A group is a list of mushroom indices that are tracked together.
using group = vector<int>;
// A "distribution": a string of group labels plus a permutation of positions.
using distrib = pair<string, vector<int>>;
// A move: the required distribution and, per machine answer, the resulting one.
using movement = pair<distrib, map<int, distrib>>;

// Sorts `v` and removes duplicate elements.
template <class T>
void sort_uniq(vector<T>& v) {
    sort(v.begin(), v.end());
    v.erase(unique(v.begin(), v.end()), v.end());
}

// Table of adaptive moves.
//
// For each entry:
//   * the first string lists the group labels consumed (one group is popped
//     from the back of each labelled pool, in order); the mushroom indices of
//     those groups are concatenated into one list;
//   * the first permutation selects which positions of that list are sent to
//     the machine, and in which order;
//   * the map is keyed by the machine's answer; its string lists the labels of
//     the groups produced and its permutation reorders the full list before it
//     is cut into consecutive groups of the sizes given by `group_sizes`.
//
// NOTE(review): the meaning of the digit labels ('0'..'3') is encoded only by
// this table; presumably they are small groups with partial information about
// their members — confirm against the author's derivation.
vector<movement> moves = {
    // A....
    {{"AWWWW", {0, 1, 2, 3, 4}}, {
        {0, {"AAAAA", {0, 1, 2, 3, 4}}},
        {1, {"A1B",   {0, 1, 2, 3, 4}}},
        {2, {"A2A",   {0, 1, 2, 3, 4}}},
        {3, {"A3B",   {0, 1, 2, 3, 4}}},
        {4, {"ABABA", {0, 1, 2, 3, 4}}},
    }},
    // .A...A
    {{"WAWWWA", {0, 1, 2, 3, 4, 5}}, {
        {0, {"AAAAAA", {0, 1, 2, 3, 4, 5}}},
        {1, {"BAAAAA", {0, 1, 2, 3, 4, 5}}},
        {2, {"AA2A",   {0, 1, 2, 3, 4, 5}}},
        {3, {"BA2A",   {0, 1, 2, 3, 4, 5}}},
        {4, {"AABABA", {0, 1, 2, 3, 4, 5}}},
        {5, {"BABABA", {0, 1, 2, 3, 4, 5}}},
    }},
    // .A.A.A
    {{"WAWAWA", {0, 1, 2, 3, 4, 5}}, {
        {0, {"AAAAAA", {0, 1, 2, 3, 4, 5}}},
        {1, {"BAAAAA", {0, 1, 2, 3, 4, 5}}},
        {2, {"AA0AA",  {0, 1, 2, 4, 3, 5}}},
        {3, {"BA0AA",  {0, 1, 2, 4, 3, 5}}},
        {4, {"AABABA", {0, 1, 2, 3, 4, 5}}},
        {5, {"BABABA", {0, 1, 2, 3, 4, 5}}},
    }},
    // .A...B
    {{"WAWWWB", {0, 1, 2, 3, 4, 5}}, {
        {1, {"AA1B", {0, 1, 2, 3, 4, 5}}},
        {2, {"BA1B", {0, 1, 2, 3, 4, 5}}},
        {3, {"AA3B", {0, 1, 2, 3, 4, 5}}},
        {4, {"BA3B", {0, 1, 2, 3, 4, 5}}},
    }},
    // 0A...B
    {{"0AWWWB", {0, 2, 3, 4, 5, 6}}, {
        {1, {"AA1BB", {0, 2, 3, 4, 5, 6, 1}}},
        {2, {"BA1BA", {0, 2, 3, 4, 5, 6, 1}}},
        {3, {"AA3BB", {0, 2, 3, 4, 5, 6, 1}}},
        {4, {"BA3BA", {0, 2, 3, 4, 5, 6, 1}}},
    }},
    // .A0A1A2A
    {{"WA1AAA", {0, 1, 2, 5, 3, 6, 4, 7}}, {
        {0, {"AAAAAAAA", {0, 1, 2, 5, 3, 6, 4, 7}}},
        {1, {"BAAAAAAA", {0, 1, 2, 5, 3, 6, 4, 7}}},
        {2, {"AAAAAABA", {0, 1, 2, 5, 3, 6, 4, 7}}},
        {3, {"BAAAAABA", {0, 1, 2, 5, 3, 6, 4, 7}}},
        {4, {"AAAABABA", {0, 1, 2, 5, 3, 6, 4, 7}}},
        {5, {"BAAABABA", {0, 1, 2, 5, 3, 6, 4, 7}}},
        {6, {"AABABABA", {0, 1, 2, 5, 3, 6, 4, 7}}},
        {7, {"BABABABA", {0, 1, 2, 5, 3, 6, 4, 7}}},
    }},
    // 0A1A
    {{"2AA", {0, 3, 1, 4}}, {
        {0, {"AAAAB", {0, 3, 1, 4, 2}}},
        {1, {"BAAAA", {0, 3, 1, 4, 2}}},
        {2, {"AABAW", {0, 3, 1, 4, 2}}},
        {3, {"BABAW", {0, 3, 1, 4, 2}}},
    }},
    // .A0A2AB1B
    {{"WA3AABB", {0, 1, 2, 5, 4, 6, 7, 3, 8}}, {
        {1, {"AAAAAABBB", {0, 1, 2, 5, 4, 6, 7, 3, 8}}},
        {2, {"BAAAAABBB", {0, 1, 2, 5, 4, 6, 7, 3, 8}}},
        {3, {"AABAAABBB", {0, 1, 2, 5, 4, 6, 7, 3, 8}}},
        {4, {"BABAAABBB", {0, 1, 2, 5, 4, 6, 7, 3, 8}}},
        {5, {"AABAAABAB", {0, 1, 2, 5, 4, 6, 7, 3, 8}}},
        {6, {"BABAAABAB", {0, 1, 2, 5, 4, 6, 7, 3, 8}}},
        {7, {"AABABABAB", {0, 1, 2, 5, 4, 6, 7, 3, 8}}},
        {8, {"BABABABAB", {0, 1, 2, 5, 4, 6, 7, 3, 8}}},
    }},
};

// One pool of groups per label; `group_sizes[i]` is the number of mushrooms in
// every group stored under label `group_labels[i]`.
const string group_labels = "AB0123W";
const int groupc = group_labels.size();
const vector<int> group_sizes = {1, 1, 2, 3, 3, 3, 1};
int group_index[256];  // label character -> index into `groups`

vector<vector<group>> groups(7);
vector<group>& a = groups[0];
vector<group>& b = groups[1];
vector<group>& g0 = groups[2];
vector<group>& g1 = groups[3];
vector<group>& g2 = groups[4];
vector<group>& g3 = groups[5];
vector<group>& w = groups[6];

// `A`/`B` record which real species the pools `a`/`b` currently stand for
// (they are exchanged by SWAP). `cA`/`cB` count mushrooms already tallied for
// the respective pool without being stored in it; `q` counts machine queries.
char A, B;
int cA, cB, q;

// Maps a group label to its pool index. The cast keeps the array subscript
// non-negative even where plain `char` is signed.
static int label_to_index(char label) {
    return group_index[static_cast<unsigned char>(label)];
}

// Resets all per-run state and fills the label -> index lookup table.
void init_groups() {
    A = 'A', B = 'B';
    cA = cB = q = 0;
    for (vector<group>& pool : groups) pool.clear();
    for (int i = 0; i < groupc; i++) {
        group_index[static_cast<unsigned char>(group_labels[i])] = i;
    }
}

// Sends a non-empty query to the machine and counts it in `q`.
int query(const vector<int>& que) {
    assert(!que.empty());
    ++q;
    return use_machine(que);
}

// Executes one move: COUNT, SWAP, or table move number `move - MOVE`.
void do_move(int move) {
    vector<int> que;
    if (move == COUNT) {
        // Interleave unknown mushrooms with members of pool `a`:
        // w, a[0], w, a[1], ...  (the query ends on a member of `a`).
        const int take = static_cast<int>(min(w.size(), a.size()));
        for (int i = 0; i < take; i++) {
            que.push_back(w.back().back());
            w.pop_back();
            que.insert(que.end(), a[i].begin(), a[i].end());
        }
        int res = query(que);
        // The parity of the answer decides the pool of the first mushroom,
        // which is stored so it can serve as a reference later.
        groups[res & 1].emplace_back(1, que.front());
        // The remaining take - 1 unknowns are only tallied, not stored.
        res >>= 1;
        assert(0 <= res && res < take);
        cA += take - 1 - res;
        cB += res;
    } else if (move == SWAP) {
        // Exchange the roles of the two pools. Groups labelled '1' and '3'
        // are reversed; groups labelled '2' must not exist at this point.
        assert(g2.empty());
        swap(A, B);
        swap(cA, cB);
        swap(a, b);
        for (group& vec : g1) reverse(vec.begin(), vec.end());
        for (group& vec : g3) reverse(vec.begin(), vec.end());
    } else {
        move -= MOVE;
        assert(0 <= move && move < static_cast<int>(moves.size()));
        auto& [req, results] = moves[move];
        auto& [req_groups, req_perm] = req;

        // Pop one group per required label and concatenate their members.
        vector<int> indices;
        for (char g : req_groups) {
            vector<group>& src = groups[label_to_index(g)];
            assert(!src.empty());
            indices.insert(indices.end(), src.back().begin(), src.back().end());
            src.pop_back();
        }
        for (int i : req_perm) que.push_back(indices[i]);

        // Look up the outcome for the machine's answer and redistribute every
        // gathered mushroom into the resulting groups.
        auto& [res_groups, res_perm] = results[query(que)];
        assert(indices.size() == res_perm.size());
        vector<int> target;
        for (int i : res_perm) target.push_back(indices[i]);
        auto target_it = target.begin();
        for (char g : res_groups) {
            const int pool = label_to_index(g);
            auto ntarget_it = target_it + group_sizes[pool];
            groups[pool].emplace_back(target_it, ntarget_it);
            target_it = ntarget_it;
        }
        assert(target_it == target.end());
    }
}

// Per-label group counts (indexed like `groups`).
using deltatype = vector<int>;
// For every table move: the counts it requires, and the distinct net changes
// (required minus produced) over all of its outcomes.
vector<pair<deltatype, vector<deltatype>>> all_deltas;

// Builds `all_deltas` from `moves`. Requires `init_groups` to have run so that
// the label lookup table is filled.
void init_deltas() {
    // Rebuild from scratch so repeated calls do not append duplicate entries.
    all_deltas.clear();
    for (auto& [req, results] : moves) {
        auto& [req_groups, req_perm] = req;
        deltatype reqd(groupc);
        vector<deltatype> deltas;
        for (char g : req_groups) reqd[label_to_index(g)]++;
        for (auto& [res_query, res] : results) {
            auto& [res_groups, res_perm] = res;
            deltatype resd = reqd;
            for (char g : res_groups) resd[label_to_index(g)]--;
            deltas.push_back(resd);
        }
        sort_uniq(deltas);
        all_deltas.emplace_back(reqd, deltas);
    }
}

// Planner limits: states with more than Q queries used are rejected, and T is
// the number of admissible (g0, g1, g2, g3) combinations (see `init_ts`).
constexpr int Q = 226;
constexpr int T = 18;
// Memo table of (value, move) per hashed state; a zero move means "not yet
// computed". NOTE: this allocates (Q+8)^3/3*T entries up front.
vector<pair<int, char>> _best((Q + 8) * (Q + 8) * (Q + 8) / 3 * T);
int ts[5][5][5][5];

// Numbers the admissible (g0, g1, g2, g3) combinations: g0 + g2 <= 1 and
// g1 + g3 <= 2.
void init_ts() {
    int tc = 0;
    for (int g0 = 0; g0 <= 3; g0++)
        for (int g1 = 0; g1 <= 3; g1++)
            for (int g2 = 0; g2 <= 3; g2++)
                for (int g3 = 0; g3 <= 3; g3++)
                    if (g0 + g2 <= 1 && g1 + g3 <= 2) ts[g0][g1][g2][g3] = tc++;
}

// Value the planner assigns to stopping the table moves with `q` queries left
// and pool sizes `a`, `b`.
// NOTE(review): presumably the number of mushrooms that can still be resolved
// by COUNT moves alone — confirm against the author's derivation.
int count_now(int q, int a, int b) {
    assert(q >= 0);
    if (a < b) swap(a, b);
    int d = max(0, q + b - a);
    return a * (q + 1) + b + (d >> 1) * ((d + 1) >> 1);
}

// Maps a planner state to its slot in `_best`.
int get_hsh(int q, const deltatype& d) {
    int a = d[0], b = d[1], s = 0;
    if (a < b) s = 1, swap(a, b);
    q = Q - q + 5;
    assert(q >= a && a >= b);
    int hsh = q * (q + 1) * (q + 2) / 6 + a * (a + 1) / 2 + b;
    hsh = hsh * 2 + s;
    hsh = hsh * T + ts[d[2]][d[3]][d[4]][d[5]];
    assert(0 <= hsh && hsh < static_cast<int>(_best.size()));
    return hsh;
}

int best_solve(int q, const deltatype& d);

// Returns {value, move} for the state "q queries used, group counts d".
// Invalid states yield {-INF, GIVEUP}. Otherwise the value is the maximum of
// stopping (`count_now`) and, for each applicable table move, the minimum
// value over that move's possible outcomes.
pair<int, int> best(int q, const deltatype& d) {
    int a = d[0], b = d[1], g0 = d[2], g1 = d[3], g2 = d[4], g3 = d[5];
    if (!(0 <= a && a <= Q - q + 5 && 0 <= b && b <= Q - q + 5 &&
          0 <= g0 && 0 <= g1 && 0 <= g2 && 0 <= g3 &&
          g0 + g2 <= 1 && g1 + g3 <= 2 && 0 <= q && q <= Q)) {
        return {-INF, GIVEUP};
    }
    // Normalise to a >= b whenever SWAP is legal (no '2' groups present).
    if (a < b && g2 == 0) return {best_solve(q, {b, a, g0, g1, g2, g3}), SWAP};

    // q <= Q is guaranteed here by the validity check above.
    int hsh = get_hsh(q, d);
    if (!_best[hsh].second) {
        int solve = count_now(Q - q, a, b);
        char move = COUNT;
        auto try_cand = [&](int cmove) {
            auto& [req, deltas] = all_deltas[cmove];
            // Skip moves whose required groups are not available.
            for (size_t i = 0; i < d.size(); i++) {
                if (d[i] < req[i]) return;
            }
            // Worst case over all outcomes of this move.
            int csolve = INF;
            for (auto& delta : deltas) {
                deltatype nd = d;
                for (size_t i = 0; i < d.size(); i++) nd[i] -= delta[i];
                csolve = min(csolve, best_solve(q + 1, nd));
            }
            if (solve < csolve) {
                solve = csolve;
                move = MOVE + cmove;
            }
        };
        if (a >= 4 && b >= 4) {
            // With both pools large, only one candidate is examined, chosen
            // by which digit-labelled group is present.
            try_cand(g3 ? 7 : g2 ? 6 : g1 ? 5 : g0 ? 4 : 3);
        } else {
            const int move_count = static_cast<int>(moves.size());
            for (int i = 0; i < move_count; i++) try_cand(i);
        }
        _best[hsh] = {solve, move};
    }
    return _best[hsh];
}

// Value component of `best`.
int best_solve(int q, const deltatype& d) {
    return best(q, d).first;
}

// Move component of `best`.
int best_move(int q, const deltatype& d) {
    return best(q, d).second;
}

// Grader entry point: returns the number of mushrooms of the same species as
// mushroom 0 among mushrooms 0..n-1.
int count_mushrooms(int n) {
    init_groups();
    init_deltas();
    init_ts();

    // Mushroom 0 seeds pool `a`; every other mushroom starts in pool `w`.
    for (int i = 0; i < n; i++) {
        if (i) {
            w.emplace_back(1, i);
        } else {
            a.emplace_back(1, i);
        }
    }

    // Phase 1: follow the planner until it prefers plain counting.
    while (w.size() >= 10) {
        int move = best_move(q, {int(a.size()), int(b.size()), int(g0.size()),
                                 int(g1.size()), int(g2.size()), int(g3.size())});
        if (move == COUNT) break;
        do_move(move);
    }

    // Dissolve leftover digit-labelled groups back into single unknowns.
    for (vector<group>* pending : {&g0, &g1, &g2, &g3}) {
        while (!pending->empty()) {
            group& members = pending->back();
            while (!members.empty()) {
                w.emplace_back(1, members.back());
                members.pop_back();
            }
            pending->pop_back();
        }
    }

    // Phase 2: count the rest, always querying against the larger pool.
    while (!w.empty()) do_move(a.size() < b.size() ? SWAP : COUNT);

    // Fold the stored reference mushrooms into the tallies.
    while (!a.empty()) a.pop_back(), cA++;
    while (!b.empty()) b.pop_back(), cB++;
    for (vector<group>& g : groups) assert(g.empty());
    assert(cA + cB == n);
    return (A == 'A') * cA + (B == 'A') * cB;
}

Compilation message (stderr)

mushrooms.cpp: In function 'void init_groups()':
mushrooms.cpp:117:65: warning: array subscript has type 'char' [-Wchar-subscripts]
  117 |     for (int i = 0; i < groupc; i++) group_index[group_labels[i]] = i;
      |                                                                 ^
In file included from /usr/include/c++/9/cassert:44,
                 from /usr/include/x86_64-linux-gnu/c++/9/bits/stdc++.h:33,
                 from mushrooms.cpp:2:
mushrooms.cpp: In function 'void do_move(int)':
mushrooms.cpp:153:34: warning: comparison of integer expressions of different signedness: 'int' and 'std::vector<std::pair<std::pair<std::__cxx11::basic_string<char>, std::vector<int> >, std::map<int, std::pair<std::__cxx11::basic_string<char>, std::vector<int> > > > >::size_type' {aka 'long unsigned int'} [-Wsign-compare]
  153 |         assert(0 <= move && move < moves.size());
      |                             ~~~~~^~~~~~~~~~~~~~
mushrooms.cpp:159:53: warning: array subscript has type 'char' [-Wchar-subscripts]
  159 |             vector<group>& src = groups[group_index[g]];
      |                                                     ^
mushrooms.cpp:172:67: warning: array subscript has type 'char' [-Wchar-subscripts]
  172 |             auto ntarget_it = target_it + group_sizes[group_index[g]];
      |                                                                   ^
mushrooms.cpp:173:32: warning: array subscript has type 'char' [-Wchar-subscripts]
  173 |             groups[group_index[g]].emplace_back(target_it, ntarget_it);
      |                                ^
mushrooms.cpp: In function 'void init_deltas()':
mushrooms.cpp:189:52: warning: array subscript has type 'char' [-Wchar-subscripts]
  189 |         for (char g : req_groups) reqd[group_index[g]]++;
      |                                                    ^
mushrooms.cpp:193:56: warning: array subscript has type 'char' [-Wchar-subscripts]
  193 |             for (char g : res_groups) resd[group_index[g]]--;
      |                                                        ^
mushrooms.cpp: In function 'int count_now(int, int, int)':
mushrooms.cpp:219:44: warning: suggest parentheses around '+' inside '>>' [-Wparentheses]
  219 |     return a * (q + 1) + b + (d >> 1) * (d + 1 >> 1);
      |                                          ~~^~~
In file included from /usr/include/c++/9/cassert:44,
                 from /usr/include/x86_64-linux-gnu/c++/9/bits/stdc++.h:33,
                 from mushrooms.cpp:2:
mushrooms.cpp: In function 'int get_hsh(int, const deltatype&)':
mushrooms.cpp:230:28: warning: comparison of integer expressions of different signedness: 'int' and 'std::vector<std::pair<int, char> >::size_type' {aka 'long unsigned int'} [-Wsign-compare]
  230 |     assert(0 <= hsh && hsh < _best.size());
      |                        ~~~~^~~~~~~~~~~~~~
mushrooms.cpp: In lambda function:
mushrooms.cpp:249:31: warning: comparison of integer expressions of different signedness: 'int' and 'std::vector<int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
  249 |             for (int i = 0; i < d.size(); i++) if (d[i] < req[i]) return;
      |                             ~~^~~~~~~~~~
mushrooms.cpp:253:35: warning: comparison of integer expressions of different signedness: 'int' and 'std::vector<int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
  253 |                 for (int i = 0; i < d.size(); i++) nd[i] -= delta[i];
      |                                 ~~^~~~~~~~~~
mushrooms.cpp: In function 'std::pair<int, int> best(int, const deltatype&)':
mushrooms.cpp:265:31: warning: comparison of integer expressions of different signedness: 'int' and 'std::vector<std::pair<std::pair<std::__cxx11::basic_string<char>, std::vector<int> >, std::map<int, std::pair<std::__cxx11::basic_string<char>, std::vector<int> > > > >::size_type' {aka 'long unsigned int'} [-Wsign-compare]
  265 |             for (int i = 0; i < moves.size(); i++) try_cand(i);
      |                             ~~^~~~~~~~~~~~~~
# | Verdict | Execution time | Memory | Grader output
Fetching results...