Submission #304368

# | Time | Username | Problem | Language | Result | Execution time | Memory
304368 | | kevinsogo | Counting Mushrooms (IOI20_mushrooms) | C++17 | 100 / 100 | 968 ms | 603700 KiB
#include "mushrooms.h"
#include <bits/stdc++.h>
using namespace std;

// Move identifiers understood by do_move() and produced by best().
// GIVEUP doubles as the "not computed yet" marker in the memo table, which is
// why every real move has a non-zero id.
const int GIVEUP = 0;
// COUNT: interleave unknown mushrooms with known ones and tally them directly.
const int COUNT = 1;
// SWAP: exchange the roles of the two known pools.
const int SWAP = 2;
// MOVE + i selects entry i of the `moves` table.
const int MOVE = 3;
// Large sentinel; -INF is the value of an unusable state.
const int INF = 1 << 30;
// A group is a list of mushroom ids that are handled together.
using group = vector<int>;
// {group labels, index permutation}; see the comment on `moves`.
using distrib = pair<string,vector<int>>;
// {request, machine answer -> resulting distribution}.
using movement = pair<distrib,map<int,distrib>>;

// Sorts `v` in ascending order and keeps a single copy of every distinct value.
template<class T>
void sort_uniq(vector<T>& v) {
    sort(begin(v), end(v));
    const auto first_duplicate = unique(begin(v), end(v));
    v.erase(first_duplicate, end(v));
}

// Table of structured moves (every move other than COUNT and SWAP); entry i
// is executed by do_move(MOVE + i).
//
// Each entry is {request, outcomes}:
//   * request = {labels, perm}. For every character of `labels`, do_move pops
//     one group from the pool with that label (see group_labels) and appends
//     its mushroom ids to a list. `perm` lists positions in that list, in the
//     order they are sent to the machine. A position missing from `perm` is
//     popped but not queried.
//   * outcomes = machine answer -> {labels, perm}. The popped ids are
//     reordered by this `perm` and cut into consecutive chunks, one per
//     character of `labels`, each as long as group_sizes gives for that label;
//     every chunk is pushed onto the pool of its label.
//
// In the per-entry comments below, gK[j] means element j of the popped group
// with label K.
vector<movement> moves = {
    // Move 0 - query layout: A W W W W
    {{"AWWWW",{0,1,2,3,4}}, {
        {0, {"AAAAA",{0,1,2,3,4}}},
        {1, {"A1B",  {0,1,2,3,4}}},
        {2, {"A2A",  {0,1,2,3,4}}},
        {3, {"A3B",  {0,1,2,3,4}}},
        {4, {"ABABA",{0,1,2,3,4}}},
    }},

    // Move 1 - query layout: W A W W W A
    {{"WAWWWA",{0,1,2,3,4,5}}, {
        {0, {"AAAAAA",{0,1,2,3,4,5}}},
        {1, {"BAAAAA",{0,1,2,3,4,5}}},
        {2, {"AA2A",  {0,1,2,3,4,5}}},
        {3, {"BA2A",  {0,1,2,3,4,5}}},
        {4, {"AABABA",{0,1,2,3,4,5}}},
        {5, {"BABABA",{0,1,2,3,4,5}}},
    }},

    // Move 2 - query layout: W A W A W A
    {{"WAWAWA",{0,1,2,3,4,5}}, {
        {0, {"AAAAAA",{0,1,2,3,4,5}}},
        {1, {"BAAAAA",{0,1,2,3,4,5}}},
        {2, {"AA0AA", {0,1,2,4,3,5}}},
        {3, {"BA0AA", {0,1,2,4,3,5}}},
        {4, {"AABABA",{0,1,2,3,4,5}}},
        {5, {"BABABA",{0,1,2,3,4,5}}},
    }},

    // Move 3 - query layout: W A W W W B
    {{"WAWWWB",{0,1,2,3,4,5}}, {
        {1, {"AA1B",{0,1,2,3,4,5}}},
        {2, {"BA1B",{0,1,2,3,4,5}}},
        {3, {"AA3B",{0,1,2,3,4,5}}},
        {4, {"BA3B",{0,1,2,3,4,5}}},
    }},

    // Move 4 - query layout: g0[0] A W W W B   (g0[1] is popped but not queried)
    {{"0AWWWB",{0,2,3,4,5,6}}, {
        {1, {"AA1BB",{0,2,3,4,5,6,1}}},
        {2, {"BA1BA",{0,2,3,4,5,6,1}}},
        {3, {"AA3BB",{0,2,3,4,5,6,1}}},
        {4, {"BA3BA",{0,2,3,4,5,6,1}}},
    }},

    // Move 5 - query layout: W A g1[0] A g1[1] A g1[2] A
    {{"WA1AAA",{0,1,2,5,3,6,4,7}}, {
        {0, {"AAAAAAAA",{0,1,2,5,3,6,4,7}}},
        {1, {"BAAAAAAA",{0,1,2,5,3,6,4,7}}},
        {2, {"AAAAAABA",{0,1,2,5,3,6,4,7}}},
        {3, {"BAAAAABA",{0,1,2,5,3,6,4,7}}},
        {4, {"AAAABABA",{0,1,2,5,3,6,4,7}}},
        {5, {"BAAABABA",{0,1,2,5,3,6,4,7}}},
        {6, {"AABABABA",{0,1,2,5,3,6,4,7}}},
        {7, {"BABABABA",{0,1,2,5,3,6,4,7}}},
    }},

    // Move 6 - query layout: g2[0] A g2[1] A   (g2[2] is popped but not queried)
    {{"2AA",{0,3,1,4}}, {
        {0, {"AAAAB",{0,3,1,4,2}}},
        {1, {"BAAAA",{0,3,1,4,2}}},
        {2, {"AABAW",{0,3,1,4,2}}},
        {3, {"BABAW",{0,3,1,4,2}}},
    }},

    // Move 7 - query layout: W A g3[0] A g3[2] A B g3[1] B
    {{"WA3AABB",{0,1,2,5,4,6,7,3,8}}, {
        {1, {"AAAAAABBB",{0,1,2,5,4,6,7,3,8}}},
        {2, {"BAAAAABBB",{0,1,2,5,4,6,7,3,8}}},
        {3, {"AABAAABBB",{0,1,2,5,4,6,7,3,8}}},
        {4, {"BABAAABBB",{0,1,2,5,4,6,7,3,8}}},
        {5, {"AABAAABAB",{0,1,2,5,4,6,7,3,8}}},
        {6, {"BABAAABAB",{0,1,2,5,4,6,7,3,8}}},
        {7, {"AABABABAB",{0,1,2,5,4,6,7,3,8}}},
        {8, {"BABABABAB",{0,1,2,5,4,6,7,3,8}}},
    }},
};

// One label per pool, in pool-index order. 'A' and 'B' hold single mushrooms
// of the two known kinds, 'W' holds single unclassified mushrooms, and
// '0'..'3' hold multi-mushroom groups (presumably groups whose members are
// only partially determined; the exact invariant of each is encoded in the
// `moves` table rather than here).
const string group_labels = "AB0123W";
const int groupc = group_labels.size();
// Number of mushrooms in a group of each pool, indexed like group_labels.
const vector<int> group_sizes = {1, 1, 2, 3, 3, 3, 1};
// Label character -> pool index; filled in by init_groups().
int group_index[256];
// The pools themselves: groups[i] is the stack of groups with label group_labels[i].
vector<vector<group>> groups(7);
// Named aliases for the individual pools.
vector<group>& a = groups[0];
vector<group>& b = groups[1];
vector<group>& g0 = groups[2];
vector<group>& g1 = groups[3];
vector<group>& g2 = groups[4];
vector<group>& g3 = groups[5];
vector<group>& w = groups[6];
// Real species currently represented by pool `a` and pool `b`; exchanged by SWAP.
char A, B;
// cA / cB: mushrooms already tallied for pool a's / pool b's species without
// being stored in a pool. q: number of machine queries issued so far.
int cA, cB, q;
// Resets the solver state: pool `a` stands for species 'A' and pool `b` for
// 'B', all tallies and the query counter are zero, every pool is empty, and
// group_index maps each label of group_labels to its pool index.
void init_groups() {
    A = 'A', B = 'B';
    cA = cB = q = 0;
    for (vector<group>& g : groups) g.clear();
    // Subscript through unsigned char: plain char may be signed, and a signed
    // char must never be used directly as an array index.
    for (int i = 0; i < groupc; i++) {
        group_index[static_cast<unsigned char>(group_labels[i])] = i;
    }
}

// Sends `que` to the grader's machine and returns its answer, counting the
// call in the global query counter `q`. The query must not be empty.
int query(const vector<int>& que) {
    assert(!que.empty());
    ++q;
    return use_machine(que);
}

// Executes one move against the machine and updates the pools.
//
//   COUNT     interleaves min(|w|, |a|) unknown mushrooms with members of pool
//             `a` and tallies the unknowns from a single answer.
//   SWAP      exchanges pools `a` and `b` together with their species labels
//             and tallies; requires pool g2 to be empty.
//   MOVE + i  runs entry i of the `moves` table.
//
// Any other value (including GIVEUP) trips the range assertion.
void do_move(int move) {
    // Pool index of a label. The cast keeps the subscript non-negative even
    // where plain char is a signed type.
    const auto pool_of = [](char label) {
        return group_index[static_cast<unsigned char>(label)];
    };

    vector<int> que;
    if (move == COUNT) {

        // Query layout: w a w a ... w a, with the most recently stored unknown first.
        const int take = static_cast<int>(min(w.size(), a.size()));
        for (int i = 0; i < take; i++) {
            que.push_back(w.back().back()); w.pop_back();
            que.insert(que.end(), a[i].begin(), a[i].end());
        }
        int res = query(que);
        // The first unknown has a single neighbour, so it alone decides the
        // parity of the answer; it joins pool a (even) or pool b (odd).
        groups[res & 1].emplace_back(1, que.front());
        // Every other unknown sits between two pool-a mushrooms and adds
        // either 0 or 2 to the answer, so half of the rest is their count.
        res >>= 1;
        assert(0 <= res && res < take);
        cA += take - 1 - res;
        cB += res;

    } else if (move == SWAP) {

        assert(g2.empty());
        swap(A, B);
        swap(cA, cB);
        swap(a, b);
        // Groups in g1 and g3 are stored in an orientation that depends on
        // which pool is `a`; reversing them keeps them valid after the swap.
        for (group& vec : g1) reverse(vec.begin(), vec.end());
        for (group& vec : g3) reverse(vec.begin(), vec.end());

    } else {

        move -= MOVE;
        assert(0 <= move && move < static_cast<int>(moves.size()));

        const auto& [req, results] = moves[move];
        const auto& [req_groups, req_perm] = req;

        // Pop one group per requested label and flatten the ids.
        vector<int> indices;
        for (char g : req_groups) {
            vector<group>& src = groups[pool_of(g)];
            assert(!src.empty());
            indices.insert(indices.end(), src.back().begin(), src.back().end());
            src.pop_back();
        }
        for (int i : req_perm) que.push_back(indices[i]);

        // at() instead of operator[]: an answer missing from the table must
        // fail loudly rather than insert an empty entry into the global table.
        const auto& [res_groups, res_perm] = results.at(query(que));
        assert(indices.size() == res_perm.size());

        // Reorder the ids as the outcome prescribes, then hand them out to
        // the resulting pools chunk by chunk.
        vector<int> target;
        for (int i : res_perm) target.push_back(indices[i]);
        auto target_it = target.begin();
        for (char g : res_groups) {
            const int pool = pool_of(g);
            auto ntarget_it = target_it + group_sizes[pool];
            groups[pool].emplace_back(target_it, ntarget_it);
            target_it = ntarget_it;
        }
        assert(target_it == target.end());

    }
}


using deltatype = vector<int>;
vector<pair<deltatype,vector<deltatype>>> all_deltas;
void init_deltas() {
    for (auto& [req, results] : moves) {
        auto& [req_groups, req_perm] = req;
        deltatype reqd(groupc);
        vector<deltatype> deltas;
        for (char g : req_groups) reqd[group_index[g]]++;
        for (auto& [res_query, res] : results) {
            auto& [res_groups, res_perm] = res;
            deltatype resd = reqd;
            for (char g : res_groups) resd[group_index[g]]--;
            deltas.push_back(resd);
        }
        sort_uniq(deltas);
        all_deltas.emplace_back(reqd, deltas);
    }
}

// Largest query index the planner considers; best() rejects any state with
// q > Q. (Presumably the query budget targeted for a full score - confirm
// against the task statement.)
constexpr int Q = 226;
// Number of (g0, g1, g2, g3) pool-size combinations that init_ts() numbers.
constexpr int T = 18;
// Memo table for best(), indexed by get_hsh(). Each entry is {value, move};
// a zero move (GIVEUP) marks an entry that has not been computed yet.
vector<pair<int,char>> _best((Q+8)*(Q+8)*(Q+8)/3*T);

// ts[n0][n1][n2][n3] is the dense id of a pool-size combination with
// n0 + n2 <= 1 and n1 + n3 <= 2; all other entries are left untouched.
int ts[5][5][5][5];

// Numbers the admissible combinations consecutively from 0, in lexicographic
// order of (n0, n1, n2, n3).
void init_ts() {
    int next_id = 0;
    for (int n0 = 0; n0 <= 1; ++n0) {
        for (int n1 = 0; n1 <= 2; ++n1) {
            for (int n2 = 0; n0 + n2 <= 1; ++n2) {
                for (int n3 = 0; n1 + n3 <= 2; ++n3) {
                    ts[n0][n1][n2][n3] = next_id++;
                }
            }
        }
    }
}

// Score of a state with pool sizes `a` and `b` and `q` queries, as used by
// best() for the option of doing nothing but COUNT from here on.
// With hi = max(a, b), lo = min(a, b) and d = max(0, q + lo - hi) the result is
//   hi * (q + 1) + lo + floor(d / 2) * ceil(d / 2).
// (Presumably an estimate of how many mushrooms can be resolved - the
// derivation is not shown in this file.)
int count_now(int q, int a, int b) {
    assert(q >= 0);
    const int hi = max(a, b);
    const int lo = min(a, b);
    const int d = max(0, q + lo - hi);
    // d is never negative, so the halves below equal d >> 1 and (d + 1) >> 1.
    const int half_down = d / 2;
    const int half_up = (d + 1) / 2;
    return hi * (q + 1) + lo + half_down * half_up;
}

// Maps a planner state to its slot in the memo table `_best`.
// `q` is the number of queries used so far; d[0], d[1] are the sizes of
// pools a and b, and d[2..5] the sizes of pools g0..g3. The slot combines the
// rank of the ordered triple (Q - q + 5 >= larger >= smaller), one bit that
// records whether d[0] < d[1], and the ts[] id of the g-pool sizes.
int get_hsh(int q, const deltatype& d) {
    int larger = d[0];
    int smaller = d[1];
    const int swapped = larger < smaller ? 1 : 0;
    if (swapped) swap(larger, smaller);

    const int top = Q - q + 5;
    assert(top >= larger && larger >= smaller);

    const int triple_rank = top*(top+1)*(top+2)/6 + larger*(larger+1)/2 + smaller;
    const int shape_id = ts[d[2]][d[3]][d[4]][d[5]];
    const int hsh = (triple_rank * 2 + swapped) * T + shape_id;
    assert(hsh >= 0 && static_cast<size_t>(hsh) < _best.size());
    return hsh;
}

int best_solve(int q, const deltatype& d);
pair<int,int> best(int q, const deltatype& d) {
    int a = d[0], b = d[1], g0 = d[2], g1 = d[3], g2 = d[4], g3 = d[5];
    if (!(0 <= a && a <= Q - q + 5 && 0 <= b && b <= Q - q + 5 &&
            0 <= g0 && 0 <= g1 && 0 <= g2 && 0 <= g3 &&
            g0 + g2 <= 1 && g1 + g3 <= 2 &&
            0 <= q && q <= Q)) return {-INF, GIVEUP};
    if (a < b && g2 == 0) return {best_solve(q, {b, a, g0, g1, g2, g3}), SWAP};
    if (Q < q) return {count_now(0, a, b), COUNT};
    int hsh = get_hsh(q, d);
    if (!_best[hsh].second) {
        int solve = count_now(Q - q, a, b);
        char move = COUNT;
        auto try_cand = [&](int cmove) {
            auto& [req, deltas] = all_deltas[cmove];
            for (int i = 0; i < d.size(); i++) if (d[i] < req[i]) return;
            int csolve = INF;
            for (auto& delta : deltas) {
                deltatype nd = d;
                for (int i = 0; i < d.size(); i++) nd[i] -= delta[i];
                csolve = min(csolve, best_solve(q + 1, nd));
            }
            if (solve < csolve) {
                solve = csolve;
                move = MOVE + cmove;
            }
        };

        if (a >= 4 && b >= 4) {
            try_cand(g3 ? 7 : g2 ? 6 : g1 ? 5 : g0 ? 4 : 3);
        } else {
            for (int i = 0; i < moves.size(); i++) try_cand(i);
        }
        _best[hsh] = {solve, move};
    }
    return _best[hsh];
}
// Value component of best(q, d).
int best_solve(int q, const deltatype& d) {
    const pair<int,int> plan = best(q, d);
    return plan.first;
}
// Move component of best(q, d).
int best_move(int q, const deltatype& d) {
    const pair<int,int> plan = best(q, d);
    return plan.second;
}

int count_mushrooms(int n) {
    init_groups();
    init_deltas();
    init_ts();
    for (int i = 0; i < n; i++) {
        (i ? w : a).emplace_back(1, i);
    }

    while (w.size() >= 10) {
        int move = best_move(q, {int(a.size()), int(b.size()), int(g0.size()), int(g1.size()), int(g2.size()), int(g3.size())});
        if (move == COUNT) break;
        do_move(move);
    }

    for (; !g0.empty(); g0.pop_back()) for (group& g = g0.back(); !g.empty(); g.pop_back()) w.emplace_back(1, g.back());
    for (; !g1.empty(); g1.pop_back()) for (group& g = g1.back(); !g.empty(); g.pop_back()) w.emplace_back(1, g.back());
    for (; !g2.empty(); g2.pop_back()) for (group& g = g2.back(); !g.empty(); g.pop_back()) w.emplace_back(1, g.back());
    for (; !g3.empty(); g3.pop_back()) for (group& g = g3.back(); !g.empty(); g.pop_back()) w.emplace_back(1, g.back());
    while (!w.empty()) do_move(a.size() < b.size() ? SWAP : COUNT);
    while (!a.empty()) a.pop_back(), cA++;
    while (!b.empty()) b.pop_back(), cB++;

    for (vector<group>& g : groups) assert(g.empty());
    assert(cA + cB == n);
    return (A == 'A') * cA + (B == 'A') * cB;
}

Compilation message (stderr)

mushrooms.cpp: In function 'void init_groups()':
mushrooms.cpp:117:65: warning: array subscript has type 'char' [-Wchar-subscripts]
  117 |     for (int i = 0; i < groupc; i++) group_index[group_labels[i]] = i;
      |                                                                 ^
In file included from /usr/include/c++/9/cassert:44,
                 from /usr/include/x86_64-linux-gnu/c++/9/bits/stdc++.h:33,
                 from mushrooms.cpp:2:
mushrooms.cpp: In function 'void do_move(int)':
mushrooms.cpp:153:34: warning: comparison of integer expressions of different signedness: 'int' and 'std::vector<std::pair<std::pair<std::__cxx11::basic_string<char>, std::vector<int> >, std::map<int, std::pair<std::__cxx11::basic_string<char>, std::vector<int> > > > >::size_type' {aka 'long unsigned int'} [-Wsign-compare]
  153 |         assert(0 <= move && move < moves.size());
      |                             ~~~~~^~~~~~~~~~~~~~
mushrooms.cpp:159:53: warning: array subscript has type 'char' [-Wchar-subscripts]
  159 |             vector<group>& src = groups[group_index[g]];
      |                                                     ^
mushrooms.cpp:172:67: warning: array subscript has type 'char' [-Wchar-subscripts]
  172 |             auto ntarget_it = target_it + group_sizes[group_index[g]];
      |                                                                   ^
mushrooms.cpp:173:32: warning: array subscript has type 'char' [-Wchar-subscripts]
  173 |             groups[group_index[g]].emplace_back(target_it, ntarget_it);
      |                                ^
mushrooms.cpp: In function 'void init_deltas()':
mushrooms.cpp:189:52: warning: array subscript has type 'char' [-Wchar-subscripts]
  189 |         for (char g : req_groups) reqd[group_index[g]]++;
      |                                                    ^
mushrooms.cpp:193:56: warning: array subscript has type 'char' [-Wchar-subscripts]
  193 |             for (char g : res_groups) resd[group_index[g]]--;
      |                                                        ^
mushrooms.cpp: In function 'int count_now(int, int, int)':
mushrooms.cpp:219:44: warning: suggest parentheses around '+' inside '>>' [-Wparentheses]
  219 |     return a * (q + 1) + b + (d >> 1) * (d + 1 >> 1);
      |                                          ~~^~~
In file included from /usr/include/c++/9/cassert:44,
                 from /usr/include/x86_64-linux-gnu/c++/9/bits/stdc++.h:33,
                 from mushrooms.cpp:2:
mushrooms.cpp: In function 'int get_hsh(int, const deltatype&)':
mushrooms.cpp:230:28: warning: comparison of integer expressions of different signedness: 'int' and 'std::vector<std::pair<int, char> >::size_type' {aka 'long unsigned int'} [-Wsign-compare]
  230 |     assert(0 <= hsh && hsh < _best.size());
      |                        ~~~~^~~~~~~~~~~~~~
mushrooms.cpp: In lambda function:
mushrooms.cpp:249:31: warning: comparison of integer expressions of different signedness: 'int' and 'std::vector<int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
  249 |             for (int i = 0; i < d.size(); i++) if (d[i] < req[i]) return;
      |                             ~~^~~~~~~~~~
mushrooms.cpp:253:35: warning: comparison of integer expressions of different signedness: 'int' and 'std::vector<int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
  253 |                 for (int i = 0; i < d.size(); i++) nd[i] -= delta[i];
      |                                 ~~^~~~~~~~~~
mushrooms.cpp: In function 'std::pair<int, int> best(int, const deltatype&)':
mushrooms.cpp:265:31: warning: comparison of integer expressions of different signedness: 'int' and 'std::vector<std::pair<std::pair<std::__cxx11::basic_string<char>, std::vector<int> >, std::map<int, std::pair<std::__cxx11::basic_string<char>, std::vector<int> > > > >::size_type' {aka 'long unsigned int'} [-Wsign-compare]
  265 |             for (int i = 0; i < moves.size(); i++) try_cand(i);
      |                             ~~^~~~~~~~~~~~~~
# | Verdict | Execution time | Memory | Grader output
Fetching results...