Submission #805477

# | Time | Username | Problem | Language | Result | Execution time | Memory
805477 | (time not captured) | PixelCat | Counting Mushrooms (IOI20_mushrooms) | C++14 | 80.43 / 100 | 8 ms | 312 KiB
#include "mushrooms.h"

#ifdef NYAOWO
#include "stub.cpp"
#endif

#include <bits/stdc++.h>
#define For(i, a, b) for(int i = a; i <= b; i++)
#define Forr(i, a, b) for(int i = a; i >= b; i--)
#define F first
#define S second
#define all(x) x.begin(), x.end()
#define sz(x) ((int)x.size())
#define eb emplace_back
// #define int LL
using namespace std;
using i32 = int32_t;
using LL = long long;
using pii = pair<int, int>;

// Counts how many mushrooms in `sus` are species A, using a single use_machine call.
//
// Assumes every mushroom in `jury` has the same known species: species A when
// `flip` is false, species B when `flip` is true. `jury` must hold at least
// as many mushrooms as `sus`.
//
// The query alternates jury / suspect, taking both vectors from the back, and ends
// on a suspect. Every suspect except the last sits between two jury members and
// adds 0 or 2 to the machine's answer; the last suspect adds 0 or 1. So
// (answer + 1) / 2 is the number of suspects whose species differs from the jury.
int alter(vector<int> jury, vector<int> sus, bool flip) {
    assert(sz(jury) >= sz(sus));
    const int total = (int)sus.size();
    const int top = (int)jury.size() - 1;

    vector<int> query;
    query.reserve(2 * total);
    for(int k = 0; k < total; k++) {
        query.push_back(jury[top - k]);
        query.push_back(sus[total - 1 - k]);
    }

    const int differing = (use_machine(query) + 1) / 2;
    // With a species-B jury the differing suspects are exactly the species-A ones.
    return flip ? differing : total - differing;
}

// Classifies two unknown mushrooms x1 and x2 with the single query [j1, x1, j2, x2].
//
// Assumes j1 and j2 are of the same known species. x1 touches both references, so
// it contributes 2 to the answer when it differs; x2 touches only j2 and
// contributes 1. Each component of the returned pair is 1 when the corresponding
// unknown differs from the references. When `flip` is non-zero, both components
// are inverted (1 then means "same species as the references").
pii ask2(int j1, int j2, int x1, int x2, int flip) {
    const int transitions = use_machine({j1, x1, j2, x2});
    const int invert = flip ? 1 : 0;
    const int first = ((transitions >> 1) & 1) ^ invert;
    const int second = (transitions & 1) ^ invert;
    return pii(first, second);
}

// Threshold predicate for the pair-query phase. Returns true while both known
// groups (sizes s0 and s1) are still smaller than 140 members.
//
// `ss` (the number of still-unknown mushrooms passed by the caller) is not
// consulted. An earlier cost-based rule that compared the remaining batch cost with
// and without one more pair query was abandoned in favour of this fixed cut-off.
bool check(int s0, int s1, int ss) {
    static_cast<void>(ss);
    const int limit = 140;
    return s0 < limit && s1 < limit;
}

int count_mushrooms(int n) {
    if(n <= 227) {
        int ans = 1;
        For(i, 1, n - 1) {
            ans += 1 - use_machine({0, i});
        }
        return ans;
    }
    
    vector<int> v0(1, 0);
    vector<int> v1;
    For(i, 1, 2) {
        if(use_machine({0, i})) v1.eb(i);
        else v0.eb(i);
        if(max(sz(v0), sz(v1)) == 2) break;
    }
    int ptr = sz(v0) + sz(v1);

    for(; ptr + 1 < n && check(sz(v0), sz(v1), n - ptr);) {
        pii res;
        if(sz(v1) > sz(v0)) {
            res = ask2(v1[0], v1[1], ptr, ptr + 1, true);
        } else {
            res = ask2(v0[0], v0[1], ptr, ptr + 1, false);
        }
        (res.F ? v1 : v0).eb(ptr);
        (res.S ? v1 : v0).eb(ptr + 1);
        ptr += 2;
    }

    // for(auto &i:v0) cerr << i << " ";
    // cerr << "\n";
    // for(auto &i:v1) cerr << i << " ";
    // cerr << "\n";
    
    int ans = sz(v0);
    int flip = 0;
    if(sz(v0) < sz(v1)) {
        v0.swap(v1);
        flip = 1;
    }
    while(ptr < n) {
        int r = min(n - ptr, sz(v0));
        vector<int> sus;
        For(i, 1, r) {
            sus.eb(ptr);
            ptr++;
        }
        ans += alter(v0, sus, flip);
    }
    return ans;
    // std::vector<int> m;
    // for (int i = 0; i < n; i++)
    // 	m.push_back(i);
    // int c1 = use_machine(m);
    // m = {0, 1};
    // int c2 = use_machine(m);
    // return c1+c2;
}

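/*
Sample inputs for local testing (presumably read by the stub included under NYAOWO):
first line is n, second line is the species of mushrooms 0..n-1 (0 = A, 1 = B).

3
0 1 1

4
0 1 0 0

Expected answers: 1 for the first sample, 3 for the second.
*/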
# | Verdict | Execution time | Memory | Grader output
Fetching results...