#include "mushrooms.h"
#include <bits/stdc++.h>
using namespace std;
// Shorthand macros. NOTE(review): none of these is referenced in the visible
// code. F and S are object-like macros, so ANY later identifier spelled F or S
// (e.g. a template parameter) is silently rewritten to first / second.
#define ll long long
#define pii pair<int, int>
#define F first
#define S second
// all(c) expands to the two comma-separated arguments ((c).begin()), ((c).end()).
#define all(c) ((c).begin()), ((c).end())
#define sz(x) ((int)(x).size())
#define ld long double
// Debug printer for std::pair: writes "(first, second)".
template<class T,class U>
ostream& operator<<(ostream& os,const pair<T,U>& p){
    return os << "(" << p.first << ", " << p.second << ")";
}
// Debug printer for std::vector: writes "{a, b, c}" ("{}" when empty).
template<class T>
ostream& operator <<(ostream& os,const vector<T>& v){
    os << "{";
    bool first = true;
    for (const auto& item : v) {
        if (!first) os << ", ";
        first = false;
        os << item;
    }
    os << "}";
    return os;
}
// Debug output helpers. When LOCAL is defined, every later use of the
// identifier cerr is rewritten to cout, so debug output goes to stdout.
#ifdef LOCAL
#define cerr cout
#else
#endif
// TRACE is defined unconditionally here, so the #ifdef TRACE branch below is
// always the one compiled; remove this line to turn trace(...) into a no-op.
#define TRACE
#ifdef TRACE
// trace(a, b, ...) prints "name : value" for each argument, separated by " | ",
// and ends the line. The stringized argument list is split at commas, so an
// argument expression that itself contains a comma (a multi-argument call, a
// template argument list) will be labelled wrongly.
// NOTE(review): trace is not used anywhere in the visible code, and names
// beginning with a double underscore (__f) are reserved for the implementation.
#define trace(...) __f(#__VA_ARGS__, __VA_ARGS__)
// Base case: the last (or only) argument; terminates the line with std::endl.
template <typename Arg1>
void __f(const char* name, Arg1&& arg1){
cerr << name << " : " << arg1 << std::endl;
}
// Recursive case: print the text before the next comma as the label of arg1,
// then recurse on the remainder of the name string and the remaining args.
template <typename Arg1, typename... Args>
void __f(const char* names, Arg1&& arg1, Args&&... args){
const char* comma = strchr(names + 1, ',');cerr.write(names, comma - names) << " : " << arg1<<" | ";__f(comma+1, args...);
}
#else
// Tracing disabled: trace(...) expands to nothing.
#define trace(...)
#endif
// Number of positions one decision-tree query group covers.
const int M = 8;
// val[p] is the class of mushroom p relative to mushroom 0, as written by get()
// and node::get(): 0 = same as mushroom 0, 1 = different (under the query model
// simulated in node::compute). NOTE(review): the size 200005 must exceed every
// index used; presumably a margin over the maximum N — confirm.
int val[200005];
// A node of a precomputed decision tree that determines the classes of M
// positions at once. Each state is an M-bit mask whose bit i is the class of the
// i-th position handed to get(). A node holds every mask still consistent with
// the answers seen on the path from the root.
struct node{
// Masks still possible at this node.
vector<int> states;
// res[r] is the child reached when query `ask` returns r; NULL when no state
// here produces answer r.
node* res[M];
// Query for this node: indices into the positions vector given to get(), in
// the order they are sent to use_machine.
vector<int> ask;
// Worst-case number of queries from this node down to a leaf (0 for a leaf).
// NOTE(review): computed but not read anywhere in the visible code.
int maxQueries;
node(){
for(int i = 0; i < M; i++) res[i] = NULL;
maxQueries = 0;
}
// Builds the subtree below this node. For every ordered sequence of at least
// two distinct indices in [0, M) (each subset in every permutation), it
// simulates the answer for every state as the number of adjacent pairs in the
// sequence whose bits differ. It keeps the sequence whose largest answer group
// is smallest; ties keep the first sequence found. The states are then split by
// answer, and each non-empty group becomes a child that is built recursively.
// NOTE(review): this is an exhaustive search at every node, so it is expensive.
// Children are allocated with new and never freed in the visible code. If no
// sequence split the states, this would recurse on the same set forever.
// Presumably that is impossible because the root's states all have bit 0 clear,
// so no two states are bitwise complements (only complements give identical
// answers to every query). Confirm this if the state set changes.
void compute(){
int n = states.size();
// A single remaining state is a leaf: nothing left to ask.
if(n == 1) return;
maxQueries = 1;
// Smallest worst-case group size found so far (starts above any real size).
int mn = 1 << 20;
for(int mask = 1; mask < (1 << M); mask++){
// perm starts as the indices in mask in increasing order, so the
// next_permutation loop below visits every ordering of this subset.
vector<int> perm;
for(int j = 0; j < M; j++) if(mask >> j & 1) perm.push_back(j);
if(((int)perm.size()) == 1) continue;
do{
// num[r] = number of states for which this query would return r
// (r <= perm.size() - 1 <= M - 1, so M slots suffice).
vector<int> num(M);
for(int s : states){
int r = 0;
for(int j = 0; j + 1 < (int)perm.size(); j++) r += (s >> perm[j] & 1) != (s >> perm[j + 1] & 1);
num[r]++;
}
int V = *max_element(num.begin(), num.end());
// Strict < keeps the first sequence found among ties.
if(V < mn){
mn = V;
ask = perm;
}
}while(next_permutation(perm.begin(), perm.end()));
}
// Partition the states by the answer the chosen query would produce.
vector<int> perm = ask;
vector<vector<int>> childStates(M);
for(int s : states){
int r = 0;
for(int j = 0; j + 1 < (int)perm.size(); j++) r += (s >> perm[j] & 1) != (s >> perm[j + 1] & 1);
childStates[r].push_back(s);
}
for(int i = 0; i < M; i++){
if(!childStates[i].empty()){
res[i] = new node();
res[i]->states = childStates[i];
res[i]->compute();
maxQueries = max(maxQueries, res[i]->maxQueries + 1);
}
}
}
// Walks the tree with real use_machine queries. At the leaf it writes bit i of
// the single remaining state into val[positions[i]] for all M positions, so
// positions must hold at least M mushroom indices. The caller in
// count_mushrooms passes mushroom 0 as positions[0], which keeps val[] relative
// to mushroom 0.
// NOTE(review): res[...] is dereferenced without a NULL or range check. This is
// safe only if use_machine answers consistently with the model in compute().
// positions is taken by value, so it is copied at every level.
void get(vector<int> positions){
if((int)states.size() == 1){
for(int i = 0; i < M; i++) val[positions[i]] = states[0] >> i & 1;
return;
}
vector<int> x;
for(int i : ask) x.push_back(positions[i]);
res[use_machine(x)]->get(positions);
}
};
// Queries the pair (0, pos), stores the answer in val[pos] (1 when pos differs
// from mushroom 0, 0 otherwise) and returns that same answer.
int get(int pos){
    const int answer = use_machine({0, pos});
    val[pos] = answer;
    return answer;
}
struct decision_tree{
node * root;
decision_tree(){
root = new node();
for(int i = 0; i < (1 << M); i+=2) root->states.push_back(i);
root->compute();
}
void get(vector<int> positions){
root->get(positions);
}
};
const int K = 6;
int count_mushrooms(int N) {
decision_tree DT;
int n = min(N, K);
for(int i = 1; i < n; i += M - 1){
int st = i, en = i + M - 2;
if(en < n){
vector<int> positions = {0};
for(int j = st; j <= en; j++) positions.push_back(j);
DT.get(positions);
} else{
for(int j = st; j < n; j++) get(j);
}
}
int curr = accumulate(val, val + n, 0);
if(n == N) return curr;
vector<vector<int>> where(2);
for(int i = 0; i < n; i++) where[val[i]].push_back(i);
int id = (int)where[1].size() > (int) where[0].size();
int R = where[id].size() - 1;
for(int i = n; i < N; i += R){
int st = i, en = min(N - 1, i + R - 1);
vector<int> positions = {where[id][0]};
for(int j = 0; j <= en - st; j++){
positions.push_back(st + j);
positions.push_back(where[id][j + 1]);
}
int V = use_machine(positions);
assert(V % 2 == 0);
V /= 2;
if(id == 1) curr += V;
else curr += en - st + 1 - V;
}
return curr;
}
/*
 * Grader results (pasted from the judge):
 *
 * #  | Verdict   | Execution time | Memory | Grader output
 * 1  | Incorrect | 616 ms         | 512 KB | Answer is not correct.
 * 2  | Halted    | 0 ms           | 0 KB   | -
 */