#include <bits/stdc++.h>
#define forsn(i, s, n) for (int i = int(s); i < int(n); i++)
#define forn(i, n) forsn(i, 0, n)
#define dforsn(i, s, n) for (int i = int(n) - 1; i >= int(s); i--)
#define dforn(i, n) dforsn(i, 0, n)
#define sz(x) int(x.size())
#define all(x) begin(x), end(x)
#define pb push_back
#define eb emplace_back
#define fst first
#define snd second
using namespace std;
using vi = vector<int>;
using ii = pair<int, int>;
using ll = long long;
int use_machine(vi x);
// Brute-forces a pair of use_machine query layouts that together identify the
// species of 5 unknown mushrooms using 4 reference mushrooms of known species.
//
// Every layout is a sequence of 9 slot indices:
//   0..4  the unknown mushrooms; in a `mask`, bit i set means unknown i has
//         the species of the slot-6 references, bit i clear means it has the
//         species of slot 5;
//   5     one reference mushroom of some species X;
//   6     a reference mushroom of the other species Y. This slot appears three
//         times; callers are expected to renumber the copies (fix() does so).
//
// The simulated answer for a layout and mask (getRta) is the number of
// adjacent positions holding different species. q1 is fixed. For every
// possible first answer r, q2[r] is the lexicographically first permutation
// of {0,1,2,3,4,5,6,6,6} whose answers are pairwise distinct over all masks
// that made q1 answer r.
//
// Returns {q1, q2, maskByRtas}, where maskByRtas[r1][r2] is the unique mask
// that yields answer r1 for q1 and r2 for q2[r1], or -1 if no mask does.
// Aborts (also under NDEBUG) if some first answer has no suitable q2 layout.
std::tuple<std::vector<int>, std::vector<std::vector<int>>, std::vector<std::vector<int>>> bf() {
    using Layout = std::vector<int>;
    const int TYPE[2] = {5, 6};   // species markers for mask bit 0 / bit 1
    constexpr int kSlots = 9;     // mushrooms per query
    constexpr int kUnknown = 5;   // slots 0..4 are unknown mushrooms
    constexpr int kMasks = 1 << kUnknown;
    constexpr int kAnswers = kSlots;  // 8 adjacent pairs -> answers 0..8

    // Simulated use_machine answer for `order` when unknown i has species
    // TYPE[bit i of mask]. A fixed-size array avoids a heap allocation per
    // call inside the permutation search.
    auto getRta = [&](const Layout &order, int mask) {
        std::array<int, kSlots> values{};
        for (int i = 0; i < kSlots; i++)
            values[i] = order[i] < kUnknown ? TYPE[mask >> order[i] & 1] : order[i];
        int rta = 0;
        for (int i = 0; i + 1 < kSlots; i++) rta += (values[i] != values[i + 1]);
        return rta;
    };

    Layout q1 = {0, 1, 2, TYPE[0], TYPE[1], 3, TYPE[1], TYPE[1], 4};
    std::vector<int> rtaOfQ1(kMasks);
    for (int mask = 0; mask < kMasks; mask++) rtaOfQ1[mask] = getRta(q1, mask);

    std::vector<Layout> q2(kAnswers);
    for (int prevRta = 0; prevRta < kAnswers; prevRta++) {
        // Start sorted so next_permutation visits every distinct ordering.
        Layout order = {0, 1, 2, 3, 4, TYPE[0], TYPE[1], TYPE[1], TYPE[1]};
        bool found = false;
        do {
            std::array<bool, kAnswers> used{};
            bool distinct = true;
            for (int mask = 0; mask < kMasks && distinct; mask++) {
                if (rtaOfQ1[mask] != prevRta) continue;
                const int currRta = getRta(order, mask);
                if (used[currRta]) distinct = false;
                else used[currRta] = true;
            }
            if (distinct) {
                q2[prevRta] = order;
                found = true;
            }
        } while (!found && std::next_permutation(order.begin(), order.end()));
        if (!found) {
            // Previously this fell off the end of the function under NDEBUG (UB).
            assert(false && "bf: no distinguishing second layout");
            std::abort();
        }
    }

    std::vector<std::vector<int>> maskByRtas(kAnswers, std::vector<int>(kAnswers, -1));
    for (int mask = 0; mask < kMasks; mask++) {
        const int rta1 = rtaOfQ1[mask];
        const int rta2 = getRta(q2[rta1], mask);
        assert(maskByRtas[rta1][rta2] == -1);  // guaranteed by the search above
        maskByRtas[rta1][rta2] = mask;
    }
    return {std::move(q1), std::move(q2), std::move(maskByRtas)};
}
// Renumbers the three slot-6 placeholders of a bf() layout to 6, 7 and 8 in
// order of appearance, so each copy indexes a distinct slot. Asserts that
// exactly three placeholders were present.
void fix(std::vector<int> &a) {
    int nextSlot = 6;
    for (int &entry : a) {
        if (entry == 6) entry = nextSlot++;
    }
    assert(nextSlot == 9);
}
// Reference-set size threshold: once either species has K known mushrooms,
// the early classification phases stop and long alternating queries take over.
const int K = 90;

// Returns how many of the mushrooms 0..n-1 have the same species as mushroom 0.
//
// NOTE(review): use_machine(v) is assumed to return the number of adjacent
// entries of v whose species differ. That is the quantity bf() simulates and
// that every phase below decodes.
//
// known[0] holds mushrooms classified with mushroom 0's species, and known[1]
// holds the others. The phases are:
//   1. compare single mushrooms with mushroom 0 until some species has 2;
//   2. classify two at a time via [a, x, b, y], where a and b share a species;
//   3. classify five at a time with the two bf() layouts (3 + 1 references);
//   4. count the remainder in batches interleaved with the larger reference set.
int count_mushrooms(int n) {
    std::vector<int> known[2] = {{0}, {}};
    auto size_of = [&](int s) { return int(known[s].size()); };
    auto largest = [&] { return std::max(size_of(0), size_of(1)); };
    auto smallest = [&] { return std::min(size_of(0), size_of(1)); };
    int next = 1;  // lowest-numbered mushroom not yet classified

    // Phase 1: a 2-element query answers 0 (same as mushroom 0) or 1.
    for (; next < n && largest() < 2; ++next)
        known[use_machine({0, next})].push_back(next);

    // Phase 2: with a, b of one species, [a, x, b, y] answers
    // 2*[x differs] + [y differs].
    while (largest() < K && (smallest() < 1 || largest() < 3) && next + 1 < n) {
        const int ref = size_of(1) >= 2 ? 1 : 0;
        assert(size_of(ref) >= 2);
        const int ans = use_machine({known[ref][0], next, known[ref][1], next + 1});
        known[ref ^ (ans >> 1 & 1)].push_back(next);
        known[ref ^ (ans & 1)].push_back(next + 1);
        next += 2;
    }

    // Phase 3: two queries decide five unknowns at once.
    auto [q1, q2, maskByRtas] = bf();
    fix(q1);
    for (std::vector<int> &layout : q2) fix(layout);
    while (next + 4 < n && largest() < K) {
        const int major = size_of(1) >= 3 ? 1 : 0;  // species supplying 3 refs
        const int minor = 1 - major;                // species supplying 1 ref
        assert(size_of(major) >= 3);
        assert(size_of(minor) >= 1);
        // Slot index -> mushroom, in bf()'s encoding after fix().
        const std::vector<int> slot = {next, next + 1, next + 2, next + 3, next + 4,
                                       known[minor][0],
                                       known[major][0], known[major][1], known[major][2]};
        auto instantiate = [&](const std::vector<int> &layout) {
            std::vector<int> query;
            query.reserve(layout.size());
            for (int s : layout) query.push_back(slot[s]);
            return query;
        };
        const int first = use_machine(instantiate(q1));
        const int second = use_machine(instantiate(q2[first]));
        const int mask = maskByRtas[first][second];
        // Bit i set: unknown i matches the major references.
        for (int i = 0; i < 5; ++i) known[minor ^ (mask >> i & 1)].push_back(next + i);
        next += 5;
    }

    // Phase 4: query [r0, u0, r1, u1, ...]. Each inner unknown adds 0 or 2 to
    // the answer, and the trailing one adds 0 or 1, so (ans + 1) / 2 unknowns
    // differ from the references and ans & 1 classifies the last one.
    int res = size_of(0);
    while (next < n) {
        const int ref = size_of(1) > size_of(0) ? 1 : 0;
        assert(!known[ref].empty());
        const int batch = size_of(ref);
        std::vector<int> query;
        for (int i = 0; next < n && i < batch; ++i) {
            query.push_back(known[ref][i]);
            query.push_back(next++);
        }
        const int ans = use_machine(query);
        const int unknowns = int(query.size()) / 2;
        const int different = (ans + 1) / 2;
        known[ref ^ (ans & 1)].push_back(next - 1);
        res += ref == 1 ? different : unknowns - different;
    }
    return res;
}
/* Judge results table pasted from the submission page (not part of the program):
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
*/