#include <bits/stdc++.h>
#include "mushrooms.h"
using namespace std;
// Phase-1 target: count_mushrooms keeps classifying mushrooms individually
// until the larger known group (same-as-0 or different-from-0) has this many
// members.
constexpr int magic = 100;
// Inputs with n <= bound + 1 are answered with one direct comparison per
// mushroom against mushroom 0.
constexpr int bound = 226;
// 64-bit Mersenne Twister seeded from the steady clock.
//   random integer in [l, r]: std::uniform_int_distribution<long long>(l, r)(rng)
//   random permutation:       std::shuffle(v.begin(), v.end(), rng)
std::mt19937_64 rng(static_cast<std::mt19937_64::result_type>(
    std::chrono::steady_clock::now().time_since_epoch().count()));
namespace {

// Result of one interleaved machine query (see probe_interleaved).
struct ProbeOutcome {
  int last_type;      // type (0 or 1) of the final probed index
  int others_differ;  // how many of the other probed indices differ from the reference type
};

// Queries ref[0], items[0], ref[1], items[1], ..., ref[k-1], items[k-1]
// (k = items.size()). Every index in `ref` must already be known to have type
// `ref_type`, and 1 <= items.size() <= ref.size().
//
// NOTE(review): the decoding assumes use_machine returns the number of
// adjacent pairs in the query whose types differ. Under that assumption every
// non-final item sits between two reference indices and contributes 0 or 2,
// while the final item contributes 0 or 1. So the parity identifies the final
// item, and the halved count tells how many of the others differ.
ProbeOutcome probe_interleaved(const std::vector<int>& ref, int ref_type,
                               const std::vector<int>& items) {
  assert(!items.empty() && items.size() <= ref.size());
  std::vector<int> query;
  query.reserve(2 * items.size());
  for (std::size_t i = 0; i < items.size(); ++i) {
    query.push_back(ref[i]);
    query.push_back(items[i]);
  }
  const int differing = use_machine(query);
  return {ref_type ^ (differing & 1), differing >> 1};
}

// Assigns a type to every index in `group`, given that exactly `ones_count`
// of them are type 1. A mixed group is split in half. Each split spends one
// query on the left half plus the fresh index `next`, which that same query
// also classifies; `next` is then advanced past it.
void classify_by_halving(const std::vector<int>& group, int ones_count,
                         const std::vector<int>& ref, int ref_type,
                         std::vector<int>& kind, int& next) {
  const int size = static_cast<int>(group.size());
  assert(ones_count >= 0 && ones_count <= size);
  if (ones_count == 0 || ones_count == size) {
    // Homogeneous (or empty) group: no query needed.
    const int t = ones_count == 0 ? 0 : 1;
    for (int idx : group) kind[idx] = t;
    return;
  }

  const int mid = size / 2;  // >= 1, since a mixed group has >= 2 members
  const std::vector<int> left(group.begin(), group.begin() + mid);
  const std::vector<int> right(group.begin() + mid, group.end());

  std::vector<int> probe_items = left;
  const int fresh = next++;
  probe_items.push_back(fresh);
  const ProbeOutcome outcome = probe_interleaved(ref, ref_type, probe_items);
  kind[fresh] = outcome.last_type;

  // others_differ counts left-half members whose type differs from ref_type.
  const int left_ones =
      ref_type == 1 ? mid - outcome.others_differ : outcome.others_differ;
  classify_by_halving(left, left_ones, ref, ref_type, kind, next);
  classify_by_halving(right, ones_count - left_ones, ref, ref_type, kind, next);
}

}  // namespace

// Returns how many of the n mushrooms (indices 0..n-1) have the same type as
// mushroom 0. Internally, type 0 means "same as mushroom 0" and type 1 means
// "different".
int count_mushrooms(int n) {
  // Small inputs: compare every mushroom directly with mushroom 0.
  if (n <= bound + 1) {
    int same = 1;  // mushroom 0 itself
    for (int i = 1; i < n; i++) {
      same += use_machine({0, i}) ^ 1;
    }
    return same;
  }

  std::vector<int> kind(n, -1);
  kind[0] = 0;
  std::vector<int> zeros = {0};
  std::vector<int> ones;
  int next = 1;  // every index below `next` has a known kind

  // Phase 1: classify mushrooms individually until the larger known group
  // reaches `magic` members.
  while (std::max(zeros.size(), ones.size()) < static_cast<std::size_t>(magic)) {
    // Use the larger group as reference (zeros on ties).
    const int ref_type = ones.size() > zeros.size() ? 1 : 0;
    const std::vector<int>& ref = ref_type == 1 ? ones : zeros;
    const int first_new = next;

    const int batch_size = static_cast<int>(
        std::min(ref.size(), static_cast<std::size_t>(magic) - ref.size()));
    std::vector<int> batch(batch_size);
    std::iota(batch.begin(), batch.end(), next);
    next += batch_size;
    std::shuffle(batch.begin(), batch.end(), rng);

    // The last batch member is identified exactly; the rest are counted, then
    // resolved by halving.
    const ProbeOutcome outcome = probe_interleaved(ref, ref_type, batch);
    kind[batch.back()] = outcome.last_type;
    batch.pop_back();
    const int batch_ones = ref_type == 1
                               ? static_cast<int>(batch.size()) - outcome.others_differ
                               : outcome.others_differ;
    classify_by_halving(batch, batch_ones, ref, ref_type, kind, next);

    // `ref` aliases zeros or ones, which are modified below; it is not used
    // after this point.
    for (int i = first_new; i < next; i++) {
      (kind[i] == 0 ? zeros : ones).push_back(i);
    }
  }

  // Phase 2: count the remaining mushrooms in bulk. Only the last index of
  // each query is individually identified (and added to its group).
  int same = static_cast<int>(zeros.size());
  while (next < n) {
    const int ref_type = ones.size() > zeros.size() ? 1 : 0;
    const std::vector<int>& ref = ref_type == 1 ? ones : zeros;
    const int len = std::min(n - next, static_cast<int>(ref.size()));
    std::vector<int> batch(len);
    std::iota(batch.begin(), batch.end(), next);
    next += len;

    const ProbeOutcome outcome = probe_interleaved(ref, ref_type, batch);
    const int others_zero = ref_type == 1 ? outcome.others_differ
                                          : (len - 1) - outcome.others_differ;
    same += others_zero + (outcome.last_type == 0 ? 1 : 0);

    const int last = batch.back();
    kind[last] = outcome.last_type;
    (outcome.last_type == 0 ? zeros : ones).push_back(last);
  }
  return same;
}
/*
Compilation message (saved compiler output; not part of the program):
In file included from /usr/include/c++/10/cassert:44,
from /usr/include/x86_64-linux-gnu/c++/10/bits/stdc++.h:33,
from mushrooms.cpp:1:
mushrooms.cpp: In function 'int count_mushrooms(int)':
mushrooms.cpp:98:22: error: 'A' was not declared in this scope
98 | assert(a[i] == A[i]);
| ^
*/