// Submission #1223528
//   Username:       steveonalex
//   Problem:        Counting Mushrooms (IOI20_mushrooms)
//   Language:       C++20
//   Result:         100 / 100
//   Execution time: 47 ms
//   Memory:         11060 KiB
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;

// MASK(i): 64-bit value with only bit i set.
#define MASK(i) (1ULL << (i))
// GETBIT(mask, i): bit i of mask (0 or 1).
#define GETBIT(mask, i) (((mask) >> (i)) & 1)
// ALL(v): begin/end iterator pair of a container, for passing to algorithms.
#define ALL(v) (v).begin(), (v).end()

// Non-template max/min on long long. With `using namespace std` in scope these
// win overload resolution over std::max/std::min when both arguments are ll.
ll max(ll a, ll b){return (a > b) ? a : b;}
ll min(ll a, ll b){return (a < b) ? a : b;}

// Greatest common divisor of |a| and |b|; gcd(0, 0) == 0.
// Uses the standard std::gcd rather than the libstdc++-internal __gcd.
ll gcd(ll a, ll b){return std::gcd(std::abs(a), std::abs(b));}

// Least common multiple of |a| and |b|. Returns 0 when either argument is 0;
// without that guard lcm(0, 0) would divide by gcd(0, 0) == 0.
ll lcm(ll a, ll b){
    if (a == 0 || b == 0) return 0;
    return std::abs(a) / gcd(a, b) * std::abs(b);
}

// Lowest set bit of mask (0 when mask == 0). The negation is done on the
// unsigned representation so LLONG_MIN does not cause signed overflow.
ll LASTBIT(ll mask){
    const ull u = static_cast<ull>(mask);
    return static_cast<ll>(u & (0ULL - u));
}
// Number of set bits.
int pop_cnt(ull mask){return __builtin_popcountll(mask);}
// Number of trailing zero bits; 64 for mask == 0 (__builtin_ctzll(0) is undefined).
int ctz(ull mask){return mask ? __builtin_ctzll(mask) : 64;}
// Index of the highest set bit, i.e. floor(log2(mask)); -1 for mask == 0
// (__builtin_clzll(0) is undefined).
int logOf(ull mask){return mask ? 63 - __builtin_clzll(mask) : -1;}
 
// Time-based seeding, disabled in favour of a fixed seed:
// mt19937_64 rng(chrono::high_resolution_clock::now().time_since_epoch().count());
// Fixed seed: every run draws the same pseudo-random sequence.
mt19937_64 rng(1);

// Integer drawn from [l, r] by reducing a 64-bit draw modulo the range width.
ll rngesus(ll l, ll r){
    const ull span = r - l + 1;
    return l + (ull) rng() % span;
}

// Real drawn from [l, r]: a 60-bit integer draw scaled into [0, 1], then
// mapped linearly onto the target interval.
double rngesus_d(double l, double r){
    const double scale = MASK(60) - 1;
    const double frac = rngesus(0, MASK(60) - 1) / scale;
    return l + frac * (r - l);
}
 
// Raise a to b when b is strictly greater. Returns true iff a was changed.
template <class T1, class T2>
    bool maximize(T1 &a, T2 b){
        if (!(a < b)) return false;
        a = b;
        return true;
    }
 
// Lower a to b when b is strictly smaller. Returns true iff a was changed.
template <class T1, class T2>
    bool minimize(T1 &a, T2 b){
        if (!(a > b)) return false;
        a = b;
        return true;
    }
 
// Writes each element of `container` to `out`, each followed by `separator`
// (including the last one), then writes `finish`.
template <class T>
    void printArr(T container, string separator = " ", string finish = "\n", ostream &out = cout){
        for (auto it = container.begin(); it != container.end(); ++it) {
            out << *it;
            out << separator;
        }
        out << finish;
    }
 
// Sorts `a` ascending and drops repeated values so each distinct value appears once.
template <class T>
    void remove_dup(vector<T> &a){
        sort(a.begin(), a.end());
        const auto new_end = unique(a.begin(), a.end());
        a.erase(new_end, a.end());
    }

#include "mushrooms.h"

// Precomputed non-adaptive subset-count query schemes.
//
// The layer-k scheme recovers f[k] unknown 0/1 values from the answers to the
// fixed queries queries_set[k] (there are 2^(k+1) - 1 of them). Each answer is
// expected to be the number of 1s among the listed indices. The construction is
// recursive: f[k] = 2 * f[k-1] + (2^k - 1), i.e. two copies of the layer-(k-1)
// scheme (left half, right half) plus 2^k - 1 extra items whose values are
// carried in the parity of paired answers.
namespace QueryStuff{
    const int N = 10;
    bool inited = false;
    // queries_set[k][q]: sorted item indices in [0, f[k]) covered by query q.
    vector<vector<int>> queries_set[N];
    // f[k]: number of items decoded by the layer-k scheme.
    int f[N];

    // Fills f[] and queries_set[] for every layer < N. Only the first call does work.
    void init(){
        if (inited) return;
        inited = true;

        f[0] = 1;
        queries_set[0].push_back({0});

        for(int layer = 1; layer < N; ++layer){
            const int half = f[layer - 1];
            f[layer] = 2 * half + (1 << layer) - 1;
            const vector<vector<int>> &prev = queries_set[layer - 1];
            vector<vector<int>> &out = queries_set[layer];

            // Query 0: the whole right half, indices [half, 2 * half).
            vector<int> right_half;
            for(int j = 0; j < half; ++j) right_half.push_back(j + half);
            out.push_back(right_half);

            for(int i = 0; i < (int) prev.size(); ++i){
                const vector<int> &q = prev[i];

                // "Sum" query: Q on the left half plus the same Q on the right half.
                vector<int> sum_query = q;
                for(int j: q) sum_query.push_back(half + j);
                out.push_back(sum_query);

                // "Difference" query: Q on the left half, the complement of Q on
                // the right half, and the i-th extra item at index 2 * half + i.
                // q is sorted, so binary_search is valid; the result stays sorted.
                vector<int> diff_query = q;
                for(int j = 0; j < half; ++j) if (!binary_search(q.begin(), q.end(), j)){
                    diff_query.push_back(j + half);
                }
                diff_query.push_back(2 * half + i);
                out.push_back(diff_query);
            }
        }
    }

    // Decodes the layer-`layer` scheme.
    //   perm:         the items the queries were about; only its first f[layer]
    //                 entries are split between recursive calls. The decoded
    //                 values do not depend on its contents.
    //   query_answer: query_answer[q] = number of 1s among the indices listed in
    //                 queries_set[layer][q].
    // Returns ans (size perm.size()) with ans[i] = value of item i.
    // Requires init() to have been called.
    vector<int> figure_out(int layer, const vector<int> &perm, const vector<int> &query_answer){
        assert((int) perm.size() >= f[layer]);
        assert(query_answer.size() == queries_set[layer].size());

        vector<int> ans(perm.size());
        const int c = query_answer[0];  // count of 1s in the whole right half
        if (layer == 0) {
            ans[0] = c;
            return ans;
        }

        const int half = f[layer - 1];
        const vector<int> perm_left(perm.begin(), perm.begin() + half);
        const vector<int> perm_right(perm.begin() + half, perm.begin() + 2 * half);
        vector<int> query_answer_left, query_answer_right;

        // Answers come in (sum, difference) pairs after query 0.
        for(int i = 1; i + 1 < (int) query_answer.size(); i += 2){
            const int pos = 2 * half + i / 2;    // extra item encoded by this pair
            const int s1 = query_answer[i];      // L(Q) + R(Q)
            const int s2 = query_answer[i + 1];  // L(Q) + (c - R(Q)) + x
            int s3 = s1 + s2 - c;                // 2 * L(Q) + x
            int s4 = s1 - s2 + c;                // 2 * R(Q) - x
            if (s3 % 2){
                // Odd parity: the extra item x is 1.
                ans[pos] = 1;
                s3--; s4++;
            }
            assert(s3 % 2 == 0 && s4 % 2 == 0);
            query_answer_left.push_back(s3 / 2);
            query_answer_right.push_back(s4 / 2);
        }

        const vector<int> left = figure_out(layer - 1, perm_left, query_answer_left);
        const vector<int> right = figure_out(layer - 1, perm_right, query_answer_right);
        for(int i = 0; i < half; ++i){
            ans[i] = left[i];
            ans[i + half] = right[i];
        }
        return ans;
    }
}

namespace Sub1{
    // Adaptive strategy. The code treats use_machine(v) as the number of
    // adjacent positions in v holding different species (grader-defined).
    //  1. Compare single unknowns against mushroom 0 until one species has two
    //     known members.
    //  2. Classify two unknowns per query by interleaving them with two known
    //     members of one species (bit 0 -> first unknown, bit 1 -> second),
    //     until some species has BLOCK known members.
    //  3. Interleave as many unknowns as the larger known species allows; the
    //     answer yields how many of them are in red, and its parity bit gives
    //     the first unknown's species.
    // red holds mushrooms of mushroom 0's species; the return value is its size
    // once every index has been accounted for.
    int count_mushrooms(int n){
        const int BLOCK = 90;
        vector<int> red{0}, blue;

        // Unknown indices, consumed from the back (largest index first).
        vector<int> perm;
        for(int i = 1; i < n; ++i) perm.push_back(i);
        auto take = [&perm]() {
            const int v = perm.back();
            perm.pop_back();
            return v;
        };

        // Phase 1: one unknown per query, compared against mushroom 0.
        while(!perm.empty() && red.size() < 2 && blue.size() < 2){
            const int u = take();
            (use_machine({0, u}) ? blue : red).push_back(u);
        }

        // Phase 2: rename so that `red` is a species with >= 2 known members.
        const bool flipped = blue.size() >= 2;
        if (flipped) swap(red, blue);
        while(!perm.empty()){
            vector<int> probe;
            for(int i = 0; i < 2 && !perm.empty(); ++i){
                probe.push_back(take());
                probe.push_back(red[i]);
            }
            const int diff = use_machine(probe);
            (GETBIT(diff, 0) ? blue : red).push_back(probe[0]);
            if (probe.size() >= 4) (GETBIT(diff, 1) ? blue : red).push_back(probe[2]);
            if (max(blue.size(), red.size()) >= BLOCK) break;
        }
        if (flipped) swap(red, blue);

        // Phase 3: count the rest in bulk against the larger known species.
        int cnt = red.size();
        while(!perm.empty()){
            const bool use_blue = blue.size() > red.size();
            const vector<int> &ref = use_blue ? blue : red;
            vector<int> probe;
            for(int i = 0; i < (int) ref.size() && !perm.empty(); ++i){
                probe.push_back(take());
                probe.push_back(ref[i]);
            }
            int diff = use_machine(probe);
            // Against red, flip the count so `diff` measures red-ness instead.
            if (!use_blue) diff = probe.size() - 1 - diff;
            (GETBIT(diff, 0) ? red : blue).push_back(probe[0]);
            cnt += diff / 2 + diff % 2;
        }
        return cnt;
    }
}

namespace Sub2{
    // Strategy built on the precomputed QueryStuff schemes. The code treats
    // use_machine(v) as the number of adjacent positions in v holding different
    // species (grader-defined).
    //  1. Compare single unknowns against mushroom 0 until one species has two
    //     known members.
    //  2. Spend a budget of BLOCK queries on QueryStuff schemes. A layer-k round
    //     asks 2^(k+1) - 1 queries and classifies f[k] unknowns (decoded by
    //     QueryStuff::figure_out) plus one extra "flank" unknown per query, read
    //     from the parity of that query's answer.
    //  3. Count the remaining unknowns by interleaving them with the larger
    //     known species.
    // red holds mushrooms of mushroom 0's species; the return value is its
    // final count. QueryStuff::init() must have been called first.
    int count_mushrooms(int n){
        // Remaining query budget for phase 2.
        int BLOCK = 90;

        vector<int> red, blue;
        red.push_back(0);

        // Unknown indices, consumed from the back.
        vector<int> perm;
        for(int i = 1; i < n; ++i) perm.push_back(i);

        // Phase 1.
        while(perm.size()){
            int u = perm.back(); perm.pop_back();
            if (use_machine({0, u})){
                blue.push_back(u);
            }
            else red.push_back(u);

            if (blue.size() >= 2 || red.size() >= 2) break;
        }

        // Phase 2.
        while(BLOCK > 0){
            // Work relative to the larger known species, called red for this round.
            bool flipped = (blue.size() >= red.size());
            if (flipped) swap(red, blue);

            // Largest layer whose scheme fits the known pool (f[layer] reference
            // slots plus red.back()) and the remaining budget. layer + 1 < N is
            // tested first so f[layer + 1] is never read out of range.
            int layer = 0;
            while(layer + 1 < QueryStuff::N
                  && QueryStuff::f[layer+1] <= (int)red.size() - 1
                  && (int)(MASK(layer+2) - 1) <= BLOCK) {
                layer++;
            }

            int cur_block = QueryStuff::f[layer], cur_op = MASK(layer+1)-1;

            // Even layer 0 needs two known members (otherwise red.back() and
            // red[0] coincide), and the round must draw cur_block + cur_op
            // unknowns. If either fails, leave the rest to phase 3 instead of
            // calling perm.back() on an empty vector.
            if (cur_block > (int)red.size() - 1 || cur_block + cur_op > (int)perm.size()) {
                if (flipped) swap(red, blue);
                break;
            }
            BLOCK -= cur_op;

            // Unknowns decoded by the scheme.
            vector<int> arr;
            for(int i = 0; i < cur_block; ++i) {
                arr.push_back(perm.back());
                perm.pop_back();
            }

            // One extra unknown per query, placed first in that query.
            vector<int> flank;
            for(int i = 0; i < cur_op; ++i) {
                flank.push_back(perm.back());
                perm.pop_back();
            }

            const vector<vector<int>> &query_list = QueryStuff::queries_set[layer];
            vector<int> query_answer;
            for(const vector<int> &qq: query_list){
                // Layout [flank, red.back(), arr[qq[0]], red[0], arr[qq[1]], red[1], ...]:
                // the flank adds 0/1 and every arr item adds 0/2 to the answer, so
                // parity -> flank, half -> number of non-red items among qq.
                vector<int> cur;
                cur.push_back(flank.back()); cur.push_back(red.back());
                flank.pop_back();
                for(int i = 0; i < (int) qq.size(); ++i){
                    cur.push_back(arr[qq[i]]);
                    cur.push_back(red[i]);
                }

                int num = use_machine(cur);
                if (num % 2) blue.push_back(cur[0]);
                else red.push_back(cur[0]);

                query_answer.push_back(num / 2);
            }

            vector<int> ans = QueryStuff::figure_out(layer, arr, query_answer);
            for(int i = 0; i < (int) ans.size(); ++i) {
                if (ans[i]) {
                    blue.push_back(arr[i]);
                }
                else red.push_back(arr[i]);
            }

            if (flipped) swap(red, blue);
        }

        // Phase 3: bulk counting against the larger known species.
        int cnt = red.size();

        while(perm.size()){
            if (blue.size() > red.size()){
                vector<int> cur;
                for(int i = 0; i < (int) blue.size(); ++i){
                    if (perm.empty()) break;
                    cur.push_back(perm.back()); perm.pop_back();
                    cur.push_back(blue[i]);
                }
                int wow = use_machine(cur);
                if (GETBIT(wow, 0)) red.push_back(cur[0]);
                else blue.push_back(cur[0]);

                cnt += wow / 2 + wow % 2;
            }
            else{
                vector<int> cur;
                for(int i = 0; i < (int) red.size(); ++i){
                    if (perm.empty()) break;
                    cur.push_back(perm.back()); perm.pop_back();
                    cur.push_back(red[i]);
                }
                int wow = use_machine(cur);
                // Against red, flip the count so it measures red-ness instead.
                wow = cur.size() - 1 - wow;
                if (GETBIT(wow, 0)) red.push_back(cur[0]);
                else blue.push_back(cur[0]);

                cnt += wow / 2 + wow % 2;
            }
        }
        return cnt;
    }
}

// Grader entry point: returns the number of mushrooms of mushroom 0's species.
// The QueryStuff tables are built first (a no-op after the first call); n up to
// 16000 uses the Sub1 strategy, larger n uses Sub2.
int count_mushrooms(int n){
    QueryStuff::init();
    const bool use_simple = n <= 16000;
    return use_simple ? Sub1::count_mushrooms(n) : Sub2::count_mushrooms(n);
}
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...