Submission #304368

# Submission time Handle Problem Language Result Execution time Memory
304368 2020-09-21T08:13:55 Z kevinsogo Counting Mushrooms (IOI20_mushrooms) C++17
100 / 100
968 ms 603700 KB
#include "mushrooms.h"
#include <bits/stdc++.h>
using namespace std;

// Move codes shared by best(), best_move() and do_move(). Any code >= MOVE
// selects the scripted entry moves[code - MOVE].
constexpr int GIVEUP = 0;
constexpr int COUNT = 1;
constexpr int SWAP = 2;
constexpr int MOVE = 3;
// Large sentinel; infeasible planner states are scored -INF.
constexpr int INF = 1 << 30;
// A group is the list of mushroom indices it contains.
using group = std::vector<int>;
// (label string, index permutation) describing a query layout or an outcome.
using distrib = std::pair<std::string, std::vector<int>>;
// A scripted move: its request layout plus the outcome keyed by each answer.
using movement = std::pair<distrib, std::map<int, distrib>>;

// Sorts `v` ascending, then removes repeated elements in place so that each
// distinct value appears exactly once.
template<class T>
void sort_uniq(std::vector<T>& v) {
    std::sort(std::begin(v), std::end(v));
    const auto first_dup = std::unique(std::begin(v), std::end(v));
    v.erase(first_dup, std::end(v));
}

// Scripted query moves used by best() and do_move(). Each entry is
//   {request, outcomes}.
// request = {labels, perm}: do_move pops the last group of each label in
//   `labels` (in that order) and concatenates their indices; `perm` lists which
//   of those indices go into the query, in query order. It may leave some out:
//   e.g. "0AWWWB" never queries the second index of its '0' group.
// outcomes maps the answer returned for this query to {labels, perm}: the
//   popped indices are reordered by `perm` and cut, left to right, into new
//   groups carrying those labels (group sizes come from group_sizes).
//   An answer missing from the map is treated as impossible; if it ever
//   occurred, the size assert in do_move would fail.
// The comment above each entry sketches the query layout: '.' for an unknown,
// 'A'/'B' for known mushrooms, digits for elements of a compound group.
// (This reading is inferred from the perms — TODO confirm.)
vector<movement> moves = {
    // A....
    {{"AWWWW",{0,1,2,3,4}}, {
        {0, {"AAAAA",{0,1,2,3,4}}},
        {1, {"A1B",  {0,1,2,3,4}}},
        {2, {"A2A",  {0,1,2,3,4}}},
        {3, {"A3B",  {0,1,2,3,4}}},
        {4, {"ABABA",{0,1,2,3,4}}},
    }},

    // .A...A
    {{"WAWWWA",{0,1,2,3,4,5}}, {
        {0, {"AAAAAA",{0,1,2,3,4,5}}},
        {1, {"BAAAAA",{0,1,2,3,4,5}}},
        {2, {"AA2A",  {0,1,2,3,4,5}}},
        {3, {"BA2A",  {0,1,2,3,4,5}}},
        {4, {"AABABA",{0,1,2,3,4,5}}},
        {5, {"BABABA",{0,1,2,3,4,5}}},
    }},

    // .A.A.A
    {{"WAWAWA",{0,1,2,3,4,5}}, {
        {0, {"AAAAAA",{0,1,2,3,4,5}}},
        {1, {"BAAAAA",{0,1,2,3,4,5}}},
        {2, {"AA0AA", {0,1,2,4,3,5}}},
        {3, {"BA0AA", {0,1,2,4,3,5}}},
        {4, {"AABABA",{0,1,2,3,4,5}}},
        {5, {"BABABA",{0,1,2,3,4,5}}},
    }},

    // .A...B
    {{"WAWWWB",{0,1,2,3,4,5}}, {
        {1, {"AA1B",{0,1,2,3,4,5}}},
        {2, {"BA1B",{0,1,2,3,4,5}}},
        {3, {"AA3B",{0,1,2,3,4,5}}},
        {4, {"BA3B",{0,1,2,3,4,5}}},
    }},

    // 0A...B
    {{"0AWWWB",{0,2,3,4,5,6}}, {
        {1, {"AA1BB",{0,2,3,4,5,6,1}}},
        {2, {"BA1BA",{0,2,3,4,5,6,1}}},
        {3, {"AA3BB",{0,2,3,4,5,6,1}}},
        {4, {"BA3BA",{0,2,3,4,5,6,1}}},
    }},

    // .A0A1A2A
    {{"WA1AAA",{0,1,2,5,3,6,4,7}}, {
        {0, {"AAAAAAAA",{0,1,2,5,3,6,4,7}}},
        {1, {"BAAAAAAA",{0,1,2,5,3,6,4,7}}},
        {2, {"AAAAAABA",{0,1,2,5,3,6,4,7}}},
        {3, {"BAAAAABA",{0,1,2,5,3,6,4,7}}},
        {4, {"AAAABABA",{0,1,2,5,3,6,4,7}}},
        {5, {"BAAABABA",{0,1,2,5,3,6,4,7}}},
        {6, {"AABABABA",{0,1,2,5,3,6,4,7}}},
        {7, {"BABABABA",{0,1,2,5,3,6,4,7}}},
    }},

    // 0A1A
    {{"2AA",{0,3,1,4}}, {
        {0, {"AAAAB",{0,3,1,4,2}}},
        {1, {"BAAAA",{0,3,1,4,2}}},
        {2, {"AABAW",{0,3,1,4,2}}},
        {3, {"BABAW",{0,3,1,4,2}}},
    }},

    // .A0A2AB1B
    {{"WA3AABB",{0,1,2,5,4,6,7,3,8}}, {
        {1, {"AAAAAABBB",{0,1,2,5,4,6,7,3,8}}},
        {2, {"BAAAAABBB",{0,1,2,5,4,6,7,3,8}}},
        {3, {"AABAAABBB",{0,1,2,5,4,6,7,3,8}}},
        {4, {"BABAAABBB",{0,1,2,5,4,6,7,3,8}}},
        {5, {"AABAAABAB",{0,1,2,5,4,6,7,3,8}}},
        {6, {"BABAAABAB",{0,1,2,5,4,6,7,3,8}}},
        {7, {"AABABABAB",{0,1,2,5,4,6,7,3,8}}},
        {8, {"BABABABAB",{0,1,2,5,4,6,7,3,8}}},
    }},
};

// One label per entry of `groups`, in the same order. 'A'/'B' hold single
// mushrooms of known type (relative to the current A/B orientation). 'W' holds
// single mushrooms of unknown type. '0'..'3' are compound groups produced and
// consumed by the scripted moves.
const string group_labels = "AB0123W";
const int groupc = group_labels.size();
// Number of mushroom indices in one group of each label (order as group_labels).
const vector<int> group_sizes = {1, 1, 2, 3, 3, 3, 1};
// Maps a label character to its index in group_labels (filled by init_groups).
int group_index[256];
// groups[i] holds every group currently carrying label group_labels[i];
// the references below name each list.
vector<vector<group>> groups(7);
vector<group>& a = groups[0];
vector<group>& b = groups[1];
vector<group>& g0 = groups[2];
vector<group>& g1 = groups[3];
vector<group>& g2 = groups[4];
vector<group>& g3 = groups[5];
vector<group>& w = groups[6];
// A and B record which real type ('A' or 'B') the a and b lists currently
// stand for; a SWAP move exchanges them.
char A, B;
// cA / cB: mushrooms tallied for the a / b side without being kept in a list.
// q: number of queries issued so far.
int cA, cB, q;
// Resets all per-run state: restores the default A/B orientation, zeroes the
// tallies and the query counter, empties every group list, and rebuilds the
// label-character -> group-index lookup table.
void init_groups() {
    A = 'A', B = 'B';
    cA = cB = q = 0;
    for (vector<group>& g : groups) g.clear();
    // Cast through unsigned char: a plain char may be signed, and a negative
    // value would index outside group_index (-Wchar-subscripts).
    for (int i = 0; i < groupc; i++)
        group_index[static_cast<unsigned char>(group_labels[i])] = i;
}

// Issues one query through the grader's use_machine, counting it in q first.
// The query must contain at least one index.
int query(const vector<int>& que) {
    assert(que.size() > 0);
    ++q;
    const int answer = use_machine(que);
    return answer;
}

// Executes one step of the strategy and updates the global group lists.
//   move == COUNT : one query built from unknown singletons (w) interleaved
//                   with known-'A' singletons (a); the first unknown is
//                   classified exactly, the others are only tallied.
//   move == SWAP  : exchange the roles of the a and b lists (no query).
//   move >= MOVE  : run the scripted query moves[move - MOVE].
void do_move(int move) {
    vector<int> que;
    if (move == COUNT) {

        // Pair as many unknowns as there are known-'A' groups.
        int take = min(w.size(), a.size());
        // Query layout: w_0, a_0, w_1, a_1, ..., w_{take-1}, a_{take-1}.
        // Unknowns are consumed (popped from w); the a groups are only read.
        for (int i = 0; i < take; i++) {
            que.push_back(w.back().back()); w.pop_back();
            que.insert(que.end(), a[i].begin(), a[i].end());
        }
        int res = query(que);
        // Assuming use_machine counts adjacent pairs of differing type (IOI
        // 2020 statement): w_0 has only a right neighbour and contributes 0 or
        // 1, while every later unknown sits between two a's and contributes
        // 0 or 2. So the low bit classifies w_0: 0 -> groups[0] (a),
        // 1 -> groups[1] (b).
        groups[res & 1].emplace_back(1, que.front());
        // The remaining half counts the B's among the other take-1 unknowns;
        // those mushrooms are tallied but their identities are not kept.
        res >>= 1;
        assert(0 <= res && res < take);
        cA += take - 1 - res;
        cB += res;

    } else if (move == SWAP) {

        // Relabel so that a and b trade roles: swap which real type each list
        // stands for, their tallies, and the lists themselves. '1' and '3'
        // groups are reversed (presumably their element order depends on the
        // A/B orientation — TODO confirm); '2' groups have no such handling,
        // so none may exist.
        assert(g2.empty());
        swap(A, B);
        swap(cA, cB);
        swap(a, b);
        for (group& vec : g1) reverse(vec.begin(), vec.end());
        for (group& vec : g3) reverse(vec.begin(), vec.end());

    } else {

        // Scripted move: index into the moves table.
        move -= MOVE;
        assert(0 <= move && move < moves.size());

        auto& [req, results] = moves[move];
        auto& [req_groups, req_perm] = req;
        // Pop the last group of each requested label, concatenating indices.
        vector<int> indices;
        for (char g : req_groups) {
            vector<group>& src = groups[group_index[g]];
            assert(!src.empty());
            indices.insert(indices.end(), src.back().begin(), src.back().end());
            src.pop_back();
        }
        // Query the popped indices in the order given by req_perm.
        for (int i : req_perm) que.push_back(indices[i]);
        // map::operator[] inserts an empty outcome for an answer the table
        // does not list; the size assert below then fails.
        auto& [res_groups, res_perm] = results[query(que)];
        assert(indices.size() == res_perm.size());
        vector<group> out_groups(res_groups.size());
        // Reorder the indices by res_perm and cut them, left to right, into
        // new groups labelled by res_groups with sizes from group_sizes.
        vector<int> target;
        for (int i : res_perm) target.push_back(indices[i]);
        auto target_it = target.begin();
        for (char g : res_groups) {
            auto ntarget_it = target_it + group_sizes[group_index[g]];
            groups[group_index[g]].emplace_back(target_it, ntarget_it);
            target_it = ntarget_it;
        }
        assert(target_it == target.end());

    }
}


// Group-count vector indexed like group_labels.
using deltatype = vector<int>;
// For each entry of `moves` (same index): {groups the move requires, sorted
// distinct net changes over its possible outcomes}. Each net change is
// (groups consumed) - (groups produced), per label.
vector<pair<deltatype,vector<deltatype>>> all_deltas;
// Rebuilds all_deltas from the moves table. Must run after group_index is
// filled (init_groups).
void init_deltas() {
    // count_mushrooms() calls this on every invocation; start from scratch so
    // repeated calls do not keep appending duplicate entries.
    all_deltas.clear();
    all_deltas.reserve(moves.size());
    for (const auto& [req, results] : moves) {
        const string& req_groups = req.first;
        // How many groups of each label the move pops.
        deltatype reqd(groupc);
        for (char g : req_groups) reqd[group_index[static_cast<unsigned char>(g)]]++;
        vector<deltatype> deltas;
        deltas.reserve(results.size());
        for (const auto& outcome : results) {
            const string& res_groups = outcome.second.first;
            // Net consumption = popped groups minus groups pushed back.
            deltatype resd = reqd;
            for (char g : res_groups) resd[group_index[static_cast<unsigned char>(g)]]--;
            deltas.push_back(std::move(resd));
        }
        // Outcomes with identical net effect are indistinguishable to the planner.
        sort_uniq(deltas);
        all_deltas.emplace_back(std::move(reqd), std::move(deltas));
    }
}

// Query budget the planner works against (presumably the full-score limit of
// the task — TODO confirm).
constexpr int Q = 226;
// Number of (g0, g1, g2, g3) combinations with g0 + g2 <= 1 and g1 + g3 <= 2;
// init_ts gives each one an id in [0, T).
constexpr int T = 18;
// Memo table for best(): {score, move code} per state, indexed by get_hsh(),
// which asserts its result is in range. About 7.7e7 entries, so this table
// dominates memory use.
// NOTE(review): identifiers starting with '_' are reserved at global scope.
vector<pair<int,char>> _best((Q+8)*(Q+8)*(Q+8)/3*T);

// ts[g0][g1][g2][g3]: dense id of a compound-group count tuple. Only tuples
// with g0 + g2 <= 1 and g1 + g3 <= 2 (each count in 0..3) receive an id.
int ts[5][5][5][5];
// Numbers the admissible tuples 0, 1, 2, ... in lexicographic order of
// (g0, g1, g2, g3), obtained by decoding a base-4 counter.
void init_ts() {
    int next_id = 0;
    for (int code = 0; code < 4 * 4 * 4 * 4; ++code) {
        const int c3 = code % 4;
        const int c2 = code / 4 % 4;
        const int c1 = code / 16 % 4;
        const int c0 = code / 64;
        const bool admissible = c0 + c2 <= 1 && c1 + c3 <= 2;
        if (admissible) {
            ts[c0][c1][c2][c3] = next_id;
            ++next_id;
        }
    }
}

// Heuristic score used by best() for spending the remaining q queries on plain
// COUNT steps, given a and b known singletons of the two types. The arguments
// are symmetric in a and b. It appears to estimate how many mushrooms can
// still be resolved — TODO confirm.
int count_now(int q, int a, int b) {
    assert(q >= 0);
    const int larger = std::max(a, b);
    const int smaller = std::min(a, b);
    // Queries left over once the smaller side has caught up with the larger.
    const int spare = std::max(0, q + smaller - larger);
    const int half_down = spare / 2;
    const int half_up = (spare + 1) / 2;
    return larger * (q + 1) + smaller + half_down * half_up;
}

int get_hsh(int q, const deltatype& d) {
    int a = d[0], b = d[1], s = 0;
    if (a < b) s = 1, swap(a, b);
    q = Q - q + 5;
    assert(q >= a && a >= b);
    int hsh = q*(q+1)*(q+2)/6 + a*(a+1)/2 + b;
    hsh = hsh * 2 + s;
    hsh = hsh * T + ts[d[2]][d[3]][d[4]][d[5]];
    assert(0 <= hsh && hsh < _best.size());
    return hsh;
}

// best() and best_solve() are mutually recursive.
int best_solve(int q, const deltatype& d);
// Memoized search for the next step.
//
// q: queries spent so far. d: group counts indexed like group_labels.
// count_mushrooms passes six entries {#A, #B, #0, #1, #2, #3}, so the W
// (unknown) count is not part of the state. The loops below only run over
// d.size() entries, so W availability is never checked here.
// Returns {score, move}. move is COUNT, SWAP, MOVE + k (moves[k]), or GIVEUP
// for an out-of-range state, whose score is -INF. A scripted move is scored by
// its worst outcome (minimum over its distinct deltas). The highest-scoring
// step wins, and ties keep the earlier candidate, with COUNT considered first.
pair<int,int> best(int q, const deltatype& d) {
    int a = d[0], b = d[1], g0 = d[2], g1 = d[3], g2 = d[4], g3 = d[5];
    // Reject negative counts and states outside the memo table's range
    // (these bounds match get_hsh and init_ts).
    if (!(0 <= a && a <= Q - q + 5 && 0 <= b && b <= Q - q + 5 &&
            0 <= g0 && 0 <= g1 && 0 <= g2 && 0 <= g3 &&
            g0 + g2 <= 1 && g1 + g3 <= 2 &&
            0 <= q && q <= Q)) return {-INF, GIVEUP};
    // Prefer a >= b by swapping roles when SWAP is allowed (no '2' groups).
    // This result is not stored in _best; the swapped state is.
    if (a < b && g2 == 0) return {best_solve(q, {b, a, g0, g1, g2, g3}), SWAP};
    // NOTE(review): unreachable — the range check above already returned for q > Q.
    if (Q < q) return {count_now(0, a, b), COUNT};
    int hsh = get_hsh(q, d);
    // A zero move code marks an unfilled slot: stored codes are COUNT (1) or
    // MOVE + k (>= 3).
    if (!_best[hsh].second) {
        // Baseline: stop planning and COUNT with the remaining queries.
        int solve = count_now(Q - q, a, b);
        char move = COUNT;
        // Scores scripted move `cmove` and keeps it if strictly better.
        auto try_cand = [&](int cmove) {
            auto& [req, deltas] = all_deltas[cmove];
            // Skip moves whose required groups are not available.
            for (int i = 0; i < d.size(); i++) if (d[i] < req[i]) return;
            // Worst case over the distinct outcomes of the query.
            int csolve = INF;
            for (auto& delta : deltas) {
                deltatype nd = d;
                for (int i = 0; i < d.size(); i++) nd[i] -= delta[i];
                csolve = min(csolve, best_solve(q + 1, nd));
            }
            if (solve < csolve) {
                solve = csolve;
                move = MOVE + cmove;
            }
        };

        // With at least four of each known type, only one move is tried: the
        // one consuming the first compound group present (checked in the
        // order 3, 2, 1, 0), else moves[3] ("WAWWWB"). Otherwise try all moves.
        if (a >= 4 && b >= 4) {
            try_cand(g3 ? 7 : g2 ? 6 : g1 ? 5 : g0 ? 4 : 3);
        } else {
            for (int i = 0; i < moves.size(); i++) try_cand(i);
        }
        _best[hsh] = {solve, move};
    }
    return _best[hsh];
}
// Score component of best() for state (q, d).
int best_solve(int q, const deltatype& d) {
    const auto choice = best(q, d);
    return choice.first;
}
// Move code (COUNT, SWAP, MOVE + k or GIVEUP) that best() picks for (q, d).
int best_move(int q, const deltatype& d) {
    const auto choice = best(q, d);
    return choice.second;
}

// Grader entry point. Returns how many of mushrooms 0..n-1 end up tallied
// under the real label 'A', i.e. the same type as mushroom 0, which seeds the
// a list while A == 'A'.
int count_mushrooms(int n) {
    init_groups();
    init_deltas();
    init_ts();
    // Mushroom 0 is the first known 'A'; every other mushroom starts unknown.
    for (int i = 0; i < n; i++) {
        (i ? w : a).emplace_back(1, i);
    }

    // Planning phase: while at least 10 unknowns remain, follow the planner
    // until it recommends plain counting.
    while (w.size() >= 10) {
        int move = best_move(q, {int(a.size()), int(b.size()), int(g0.size()), int(g1.size()), int(g2.size()), int(g3.size())});
        if (move == COUNT) break;
        do_move(move);
    }

    // Dissolve every compound group back into unknown singletons.
    for (; !g0.empty(); g0.pop_back()) for (group& g = g0.back(); !g.empty(); g.pop_back()) w.emplace_back(1, g.back());
    for (; !g1.empty(); g1.pop_back()) for (group& g = g1.back(); !g.empty(); g.pop_back()) w.emplace_back(1, g.back());
    for (; !g2.empty(); g2.pop_back()) for (group& g = g2.back(); !g.empty(); g.pop_back()) w.emplace_back(1, g.back());
    for (; !g3.empty(); g3.pop_back()) for (group& g = g3.back(); !g.empty(); g.pop_back()) w.emplace_back(1, g.back());
    // Counting phase: always COUNT with the larger known list, swapping roles
    // first whenever b is larger (no '2' groups remain, so SWAP is allowed).
    while (!w.empty()) do_move(a.size() < b.size() ? SWAP : COUNT);
    // Known singletons still in the lists are added to their tallies.
    while (!a.empty()) a.pop_back(), cA++;
    while (!b.empty()) b.pop_back(), cB++;

    for (vector<group>& g : groups) assert(g.empty());
    assert(cA + cB == n);
    // Undo any SWAPs: report the tally that belongs to the real label 'A'.
    return (A == 'A') * cA + (B == 'A') * cB;
}

Compilation message

mushrooms.cpp: In function 'void init_groups()':
mushrooms.cpp:117:65: warning: array subscript has type 'char' [-Wchar-subscripts]
  117 |     for (int i = 0; i < groupc; i++) group_index[group_labels[i]] = i;
      |                                                                 ^
In file included from /usr/include/c++/9/cassert:44,
                 from /usr/include/x86_64-linux-gnu/c++/9/bits/stdc++.h:33,
                 from mushrooms.cpp:2:
mushrooms.cpp: In function 'void do_move(int)':
mushrooms.cpp:153:34: warning: comparison of integer expressions of different signedness: 'int' and 'std::vector<std::pair<std::pair<std::__cxx11::basic_string<char>, std::vector<int> >, std::map<int, std::pair<std::__cxx11::basic_string<char>, std::vector<int> > > > >::size_type' {aka 'long unsigned int'} [-Wsign-compare]
  153 |         assert(0 <= move && move < moves.size());
      |                             ~~~~~^~~~~~~~~~~~~~
mushrooms.cpp:159:53: warning: array subscript has type 'char' [-Wchar-subscripts]
  159 |             vector<group>& src = groups[group_index[g]];
      |                                                     ^
mushrooms.cpp:172:67: warning: array subscript has type 'char' [-Wchar-subscripts]
  172 |             auto ntarget_it = target_it + group_sizes[group_index[g]];
      |                                                                   ^
mushrooms.cpp:173:32: warning: array subscript has type 'char' [-Wchar-subscripts]
  173 |             groups[group_index[g]].emplace_back(target_it, ntarget_it);
      |                                ^
mushrooms.cpp: In function 'void init_deltas()':
mushrooms.cpp:189:52: warning: array subscript has type 'char' [-Wchar-subscripts]
  189 |         for (char g : req_groups) reqd[group_index[g]]++;
      |                                                    ^
mushrooms.cpp:193:56: warning: array subscript has type 'char' [-Wchar-subscripts]
  193 |             for (char g : res_groups) resd[group_index[g]]--;
      |                                                        ^
mushrooms.cpp: In function 'int count_now(int, int, int)':
mushrooms.cpp:219:44: warning: suggest parentheses around '+' inside '>>' [-Wparentheses]
  219 |     return a * (q + 1) + b + (d >> 1) * (d + 1 >> 1);
      |                                          ~~^~~
In file included from /usr/include/c++/9/cassert:44,
                 from /usr/include/x86_64-linux-gnu/c++/9/bits/stdc++.h:33,
                 from mushrooms.cpp:2:
mushrooms.cpp: In function 'int get_hsh(int, const deltatype&)':
mushrooms.cpp:230:28: warning: comparison of integer expressions of different signedness: 'int' and 'std::vector<std::pair<int, char> >::size_type' {aka 'long unsigned int'} [-Wsign-compare]
  230 |     assert(0 <= hsh && hsh < _best.size());
      |                        ~~~~^~~~~~~~~~~~~~
mushrooms.cpp: In lambda function:
mushrooms.cpp:249:31: warning: comparison of integer expressions of different signedness: 'int' and 'std::vector<int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
  249 |             for (int i = 0; i < d.size(); i++) if (d[i] < req[i]) return;
      |                             ~~^~~~~~~~~~
mushrooms.cpp:253:35: warning: comparison of integer expressions of different signedness: 'int' and 'std::vector<int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
  253 |                 for (int i = 0; i < d.size(); i++) nd[i] -= delta[i];
      |                                 ~~^~~~~~~~~~
mushrooms.cpp: In function 'std::pair<int, int> best(int, const deltatype&)':
mushrooms.cpp:265:31: warning: comparison of integer expressions of different signedness: 'int' and 'std::vector<std::pair<std::pair<std::__cxx11::basic_string<char>, std::vector<int> >, std::map<int, std::pair<std::__cxx11::basic_string<char>, std::vector<int> > > > >::size_type' {aka 'long unsigned int'} [-Wsign-compare]
  265 |             for (int i = 0; i < moves.size(); i++) try_cand(i);
      |                             ~~^~~~~~~~~~~~~~
# Verdict Execution time Memory Grader output
1 Correct 355 ms 602104 KB Output is correct
2 Correct 355 ms 602232 KB Output is correct
3 Correct 354 ms 602232 KB Output is correct
4 Correct 902 ms 602488 KB Output is correct
5 Correct 901 ms 602252 KB Output is correct
6 Correct 935 ms 602316 KB Output is correct
7 Correct 908 ms 603444 KB Output is correct
8 Correct 928 ms 603572 KB Output is correct
9 Correct 922 ms 603444 KB Output is correct
10 Correct 942 ms 603444 KB Output is correct
11 Correct 927 ms 603576 KB Output is correct
12 Correct 941 ms 603388 KB Output is correct
13 Correct 916 ms 603488 KB Output is correct
14 Correct 906 ms 602872 KB Output is correct
15 Correct 910 ms 603444 KB Output is correct
16 Correct 906 ms 603444 KB Output is correct
17 Correct 913 ms 602872 KB Output is correct
18 Correct 932 ms 603444 KB Output is correct
19 Correct 917 ms 603444 KB Output is correct
20 Correct 920 ms 603444 KB Output is correct
21 Correct 914 ms 603444 KB Output is correct
22 Correct 920 ms 603444 KB Output is correct
23 Correct 923 ms 603396 KB Output is correct
24 Correct 919 ms 602804 KB Output is correct
25 Correct 925 ms 603444 KB Output is correct
26 Correct 937 ms 603448 KB Output is correct
27 Correct 950 ms 603580 KB Output is correct
28 Correct 919 ms 603572 KB Output is correct
29 Correct 947 ms 603532 KB Output is correct
30 Correct 918 ms 603380 KB Output is correct
31 Correct 909 ms 603444 KB Output is correct
32 Correct 915 ms 603444 KB Output is correct
33 Correct 924 ms 603580 KB Output is correct
34 Correct 922 ms 603444 KB Output is correct
35 Correct 947 ms 603388 KB Output is correct
36 Correct 921 ms 603444 KB Output is correct
37 Correct 915 ms 603452 KB Output is correct
38 Correct 912 ms 603444 KB Output is correct
39 Correct 921 ms 603444 KB Output is correct
40 Correct 927 ms 603444 KB Output is correct
41 Correct 924 ms 603444 KB Output is correct
42 Correct 914 ms 603444 KB Output is correct
43 Correct 929 ms 603444 KB Output is correct
44 Correct 935 ms 603444 KB Output is correct
45 Correct 950 ms 603444 KB Output is correct
46 Correct 940 ms 603444 KB Output is correct
47 Correct 926 ms 603444 KB Output is correct
48 Correct 931 ms 603424 KB Output is correct
49 Correct 932 ms 603444 KB Output is correct
50 Correct 914 ms 603484 KB Output is correct
51 Correct 968 ms 603572 KB Output is correct
52 Correct 924 ms 603444 KB Output is correct
53 Correct 935 ms 603444 KB Output is correct
54 Correct 960 ms 603444 KB Output is correct
55 Correct 957 ms 603444 KB Output is correct
56 Correct 955 ms 603444 KB Output is correct
57 Correct 949 ms 603444 KB Output is correct
58 Correct 947 ms 603700 KB Output is correct
59 Correct 910 ms 603444 KB Output is correct
60 Correct 907 ms 603444 KB Output is correct
61 Correct 924 ms 603448 KB Output is correct
62 Correct 352 ms 602104 KB Output is correct
63 Correct 359 ms 602232 KB Output is correct
64 Correct 354 ms 602232 KB Output is correct
65 Correct 359 ms 602232 KB Output is correct
66 Correct 355 ms 602168 KB Output is correct
67 Correct 364 ms 602360 KB Output is correct
68 Correct 351 ms 602104 KB Output is correct
69 Correct 360 ms 602104 KB Output is correct
70 Correct 352 ms 602104 KB Output is correct
71 Correct 355 ms 602232 KB Output is correct
72 Correct 353 ms 602104 KB Output is correct
73 Correct 358 ms 602232 KB Output is correct
74 Correct 354 ms 602104 KB Output is correct
75 Correct 362 ms 602232 KB Output is correct
76 Correct 357 ms 602128 KB Output is correct
77 Correct 358 ms 602104 KB Output is correct
78 Correct 351 ms 602104 KB Output is correct
79 Correct 360 ms 602232 KB Output is correct
80 Correct 354 ms 602104 KB Output is correct
81 Correct 354 ms 602360 KB Output is correct
82 Correct 352 ms 602104 KB Output is correct
83 Correct 355 ms 602232 KB Output is correct
84 Correct 357 ms 602104 KB Output is correct
85 Correct 359 ms 602360 KB Output is correct
86 Correct 352 ms 602104 KB Output is correct
87 Correct 361 ms 602232 KB Output is correct
88 Correct 359 ms 602104 KB Output is correct
89 Correct 357 ms 602232 KB Output is correct
90 Correct 359 ms 602104 KB Output is correct
91 Correct 367 ms 602104 KB Output is correct