#include "mushrooms.h"
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
// Short type aliases used throughout the file.
#define ll long long
#define ld long double
#define ull unsigned long long
// Pair member shorthands.
#define ff first
#define ss second
// Container type shorthands.
#define pii pair<int,int>
#define pll pair<long long, long long>
#define vi vector<int>
#define vl vector<long long>
#define pb push_back
// Loop helpers: rep is half-open [0, b); rep2/rep3 are inclusive [a, b].
#define rep(i, b) for(int i = 0; i < (b); ++i)
#define rep2(i,a,b) for(int i = a; i <= (b); ++i)
#define rep3(i,a,b,c) for(int i = a; i <= (b); i+=c)
#define count_bits(x) __builtin_popcountll((x))
#define all(x) (x).begin(),(x).end()
// Container size as a signed int (avoids signed/unsigned comparisons).
#define siz(x) (int)(x).size()
#define forall(it,x) for(auto& it:(x))
using namespace __gnu_pbds;
using namespace std;
// Order-statistics tree from GNU pb_ds; not referenced in the code visible here.
typedef tree<int, null_type, less<int>, rb_tree_tag,tree_order_statistics_node_update> ordered_set;
// Disabled random-number helpers kept from a template.
//mt19937 mt;void random_start(){mt.seed(chrono::time_point_cast<chrono::milliseconds>(chrono::high_resolution_clock::now()).time_since_epoch().count());}
//ll los(ll a, ll b) {return a + (mt() % (b-a+1));}
// Template constants; not referenced by the code visible in this file.
const int INF = 1e9+50;
const ll INF_L = 1e18+40;
const ll MOD = 1e9+7;
// queries[k][j]: bit mask of batch positions included in the j-th get_sum call
// of batch scheme k (built by get_queries).
vector<bitset<7000>> queries[8];
// bit_ans[k]: number of unclassified mushrooms scheme k classifies.
int bit_ans[8];
// min_known[k]: classified-mushroom threshold scheme k requires before it may be
// planned (computed in get_queries, checked in init).
int min_known[8];
// dp[i] = {fewest batch queries, last scheme k} for the planning DP in init.
pii dp[1001];
// Indices known to share mushroom 0's species ("ones") / to differ from it ("zeros").
vi ones;
vi zeros;
// Indices not yet classified; get_un hands them out smallest first.
set<int> not_know;
// Scheme indices k to run in the batch phase (sorted ascending by init).
vi moves;
// Running count of mushrooms in the "ones" species; starts at 1 for mushroom 0.
int sum1 = 1;
// Total number of mushrooms n; used to bound query length.
int N;
// Takes the smallest index still waiting to be classified out of `not_know`
// and returns it; yields -1 once no such index is left.
int get_un()
{
    if (not_know.empty())
        return -1;
    const auto smallest = not_know.begin();
    const int idx = *smallest;
    not_know.erase(smallest);
    return idx;
}
/// Uses one use_machine query to count how many mushrooms in v share the
/// species of mushroom 0 (the "ones" list).
///
/// The query interleaves v with already-classified mushrooms of whichever
/// species list is larger: [un?, k0, v0, k1, v1, ...]. When that list has more
/// than one member, one extra unclassified mushroom `un` (taken via get_un) is
/// placed in front. It is classified from the parity of the answer as a
/// by-product: it is pushed to ones/zeros, and sum1 is updated.
///
/// Assumes use_machine returns the number of adjacent pairs of different
/// species (declared in mushrooms.h, not visible here; TODO confirm).
/// Callers must keep siz(v) + 1 <= size of the larger list, or siz(v) <= 1
/// when that list has a single member, because each v element is followed by
/// a fresh known mushroom.
/// @return number of mushrooms in v that belong to the "ones" species.
int get_sum(vi v)
{
    // Interleave with the more numerous known species (ties go to "ones").
    const bool pattern_is_one = siz(ones) >= siz(zeros);
    const vi& known = pattern_is_one ? ones : zeros;

    // With only one known mushroom there is none to spare for the front
    // position, so no extra unknown can be classified in this query.
    const bool has_lead = siz(known) != 1;
    const int un = has_lead ? get_un() : -1;

    vi query;
    int cur_p = 0;
    if (un != -1) query.pb(un);
    if (has_lead) query.pb(known[cur_p++]);
    for (int item : v)
    {
        query.pb(item);
        query.pb(known[cur_p++]);
    }

    // Nothing to ask (empty v and no lead element): the machine rejects an
    // empty query, and nothing in v can be of either species.
    if (query.empty()) return 0;
    // Programming error: the machine accepts at most N elements. Fail loudly
    // instead of the former `1/0` (undefined behaviour).
    if (siz(query) > N) abort();

    int s = use_machine(query);
    if (un != -1)
    {
        // un sits next to known[0]: an odd answer means they differ.
        const bool un_differs = (s % 2 == 1);
        if (un_differs) s--;
        const bool un_is_one = (un_differs != pattern_is_one);
        if (un_is_one)
        {
            ones.pb(un);
            sum1++;
        }
        else
        {
            zeros.pb(un);
        }
    }

    // Each v element of the other species adds 2 transitions. The first one
    // adds only 1 when there is no lead element, so round up.
    const int differing = (s + 1) / 2;
    return pattern_is_one ? siz(v) - differing : differing;
}
// Recursively builds the batch query scheme for every level 0..k.
// Scheme k classifies bit_ans[k] unknown mushrooms with (1 << k) get_sum
// calls. queries[k][j] is the mask of batch positions included in call j.
// Layout for k > 0 (b = bit_ans[k-1]):
// - positions [0, b) hold one copy of scheme k-1 ("left");
// - positions [b, 2b) hold a second copy ("right");
// - positions [2b, bit_ans[k]) hold one new mushroom per query pair.
// min_known[k] is derived from the sorted query sizes. Presumably it is the
// number of classified mushrooms needed before scheme k can run, allowing for
// get_sum's capacity and one extra mushroom classified per earlier call.
// NOTE(review): confirm against get_sum's capacity requirement.
void get_queries(int k)
{
if(k == 0)
{
// Base scheme: one mushroom, a single query containing only it.
min_known[k] = 1;
bit_ans[k] = 1;
queries[k].resize(1);
queries[k][0] = 1;
return;
}
get_queries(k-1);
// Two copies of scheme k-1 plus (2^(k-1) - 1) new mushrooms.
bit_ans[k] = 2*bit_ans[k-1] + (1 << (k-1))-1;
queries[k].resize((1 << k));
// First batch position reserved for the new mushrooms.
int cur_new = bit_ans[k-1]*2;
rep(i,(1 << (k-1))-1)
{
// Query 2i: sub-query i applied to both copies.
queries[k][i*2] = queries[k-1][i] | (queries[k-1][i] << bit_ans[k-1]);
// Query 2i+1: sub-query i on the left copy, its complement within the right
// copy (queries[k-1].back() covers every position of scheme k-1), plus new
// mushroom number i.
queries[k][i*2+1] = queries[k-1][i] | ((queries[k-1].back() << bit_ans[k-1]) ^ (queries[k-1][i] << bit_ans[k-1]));
queries[k][i*2+1][cur_new] = 1;
cur_new++;
}
// Second-to-last query: the whole right copy.
rep2(i,bit_ans[k-1],bit_ans[k-1]*2-1) queries[k][(1 << k)-2][i] = 1;
// Last query: every position of the batch.
rep(i,bit_ans[k]) queries[k][(1 << k)-1][i] = 1;
vi sizes;
forall(it,queries[k]) sizes.pb(it.count());
// Queries are asked in ascending size order (see get_vals); the i-th one
// needs 2*size + 1 - i classified mushrooms at the start of the scheme.
sort(all(sizes));
min_known[k] = 0;
rep(i,siz(sizes))
{
min_known[k] = max(min_known[k],sizes[i]*2 + 1 - i);
}
}
// Estimates how many greedy queries are needed to classify `rest` mushrooms.
// The simulated capacity starts at 2*k; each simulated query removes
// ceil(capacity / 2) mushrooms from `rest`, and the capacity then grows by one.
// Returns 0 when rest <= 0.
int know_rest(int k, int rest)
{
    int queries_needed = 0;
    for (int known = 2 * k; rest > 0; ++known, ++queries_needed)
        rest -= (known + 1) / 2;
    return queries_needed;
}
// Decodes the answers of batch scheme k and classifies every mushroom in p.
// p[i] is the mushroom at batch position i (bit i of the queries[k] masks).
// anses[j] is the "ones" count get_sum returned for queries[k][j].
// Classified mushrooms are appended to ones/zeros, and sum1 is updated.
void get_vals_rek(vi p, vi anses, int k)
{
if(k == 0)
{
// Single-mushroom scheme: the answer is directly its "ones" count (0 or 1).
if(anses[0] == 1)
{
ones.pb(p[0]);
sum1++;
}
else
{
zeros.pb(p[0]);
}
return;
}
vi p_left;
vi p_right;
// Query (1<<k)-2 covers the right copy; query (1<<k)-1 covers the whole batch.
int right_sum = anses[(1 << k)-2];
int total_sum = anses[(1 << k)-1];
// Still includes the new mushrooms; reduced below as each new "one" is found.
int left_sum = total_sum - right_sum;
vi left_anses;
vi right_anses;
rep(i,bit_ans[k-1]) p_left.pb(p[i]);
rep2(i,bit_ans[k-1],bit_ans[k-1]*2-1) p_right.pb(p[i]);
int cur_new = bit_ans[k-1]*2;
rep(i,(1 << (k-1))-1)
{
// With L, R the sub-query-i counts on the left/right copies and x the new
// mushroom:
//   a = L + R
//   b = L + (right_sum - R) + x
// Hence a - b + right_sum = 2R - x, whose parity is x.
int a = anses[i*2];
int b = anses[i*2+1];
// +500000000 (even) keeps the operand non-negative so % 2 yields 0 or 1.
int new_ind = (a-b+right_sum+500000000) % 2;
if(new_ind == 1)
{
sum1++;
ones.pb(p[cur_new++]);
left_sum--;
}
else
{
zeros.pb(p[cur_new++]);
}
// Recover R and L for sub-query i of each copy.
int right_ans = (a-b+right_sum+new_ind)/2;
int left_ans = (a+b-right_sum-new_ind)/2;
left_anses.pb(left_ans);
right_anses.pb(right_ans);
}
// The last sub-query of scheme k-1 is the total over each copy.
left_anses.pb(left_sum);
right_anses.pb(right_sum);
get_vals_rek(p_left,left_anses,k-1);
get_vals_rek(p_right,right_anses,k-1);
}
// Runs batch scheme k on the unclassified mushrooms p (p[i] = batch position i).
// Queries are issued smallest-first (ties by index) so that early get_sum
// calls need the fewest known mushrooms. The collected answers are then
// decoded by get_vals_rek.
void get_vals(vi p, int k)
{
    const vector<bitset<7000>>& scheme = queries[k];
    const int query_count = siz(scheme);

    // (popcount, query index) pairs, sorted to fix the asking order.
    vector<pii> order;
    order.reserve(query_count);
    for (int idx = 0; idx < query_count; ++idx)
        order.emplace_back(static_cast<int>(scheme[idx].count()), idx);
    sort(order.begin(), order.end());

    vi anses(query_count);
    for (const pii& entry : order)
    {
        const bitset<7000>& mask = scheme[entry.second];
        vi members;
        for (size_t bit = 0; bit < mask.size(); ++bit)
            if (mask[bit])
                members.push_back(p[bit]);
        anses[entry.second] = get_sum(members);
    }
    get_vals_rek(p, anses, k);
}
/// Plans the batch phase for n mushrooms and stores it in `moves`.
/// dp[i] = {fewest get_sum calls, last scheme k} for chaining batch schemes up
/// to i. A step with scheme k from dp[i - bit_ans[k]] is allowed only when
/// i - bit_ans[k] >= min_known[k]. For each i, the remaining work is estimated
/// with know_rest, and the i with the smallest estimate wins. Its scheme chain
/// is then walked back into `moves` (sorted ascending).
void init(int n)
{
    // Mushroom 0 defines the "ones" species.
    ones.pb(0);
    get_queries(7);
    dp[0] = {0,0};
    dp[1] = {0,0};
    // {estimated total queries, chosen i}; .ss stays -1 if no i is tried.
    pii total_min = {1e9,-1};
    rep2(i,2,min(n,1000))
    {
        dp[i] = {1e9,-1};
        rep2(k,0,7)
        {
            if(i - bit_ans[k] >= min_known[k])
            {
                dp[i] = min(dp[i],{dp[i - bit_ans[k]].ff+(1 << k),k});
            }
        }
        // Each get_sum call may classify one extra mushroom, so roughly
        // dp[i].ff + i mushrooms are known after the batch phase.
        total_min = min(total_min,{dp[i].ff + know_rest((dp[i].ff + i)/2,n-(dp[i].ff + i)),i});
    }
    // Walk the chosen chain back to 0. Reaching cur_elms == 1 also emits
    // dp[1].ss (scheme 0). For n < 2 the loop above never runs and
    // total_min.ss is -1; testing `> 0` rather than `!= 0` avoids reading
    // dp[-1] (out of bounds) in that case.
    int cur_elms = total_min.ss;
    while(cur_elms > 0)
    {
        moves.pb(dp[cur_elms].ss);
        cur_elms -= bit_ans[dp[cur_elms].ss];
    }
    sort(all(moves));
}
/// Entry point: returns how many of mushrooms 0..n-1 share mushroom 0's species.
/// Phase 1 runs the batch schemes planned by init (`moves`). Phase 2 greedily
/// classifies the rest with get_sum, each query covering as many unknowns as
/// the larger known species list allows.
int count_mushrooms(int n)
{
    N = n;
    // Only mushroom 0 exists, and it trivially matches itself. init() has no
    // candidate plan for n < 2, so skip it.
    if (n == 1) return sum1;
    init(n);
    rep2(i,1,n-1) not_know.insert(i);
    forall(it,moves)
    {
        // The plan can request more mushrooms than remain. For n == 2 it is
        // [0, 0] for one unknown, and get_un() would then hand out -1,
        // producing an invalid machine query. Stop and let the greedy phase
        // finish.
        if (siz(not_know) < bit_ans[it]) break;
        vi new_p;
        rep(i,bit_ans[it])
        {
            new_p.pb(get_un());
        }
        get_vals(new_p,it);
    }
    while(siz(not_know) > 0)
    {
        // Always take at least one mushroom so every iteration makes progress.
        // With one known mushroom per species the old bound was 0, giving an
        // empty query and an endless loop. get_sum handles a single element
        // against a single known mushroom.
        const int cap = max({siz(ones)-1, siz(zeros)-1, 1});
        const int z = min(siz(not_know), cap);
        vi p;
        rep(i,z)
        {
            p.pb(get_un());
        }
        sum1 += get_sum(p);
    }
    return sum1;
}
Compilation message (stderr)
mushrooms.cpp: In function 'int get_sum(std::vector<int>)':
mushrooms.cpp:67:34: warning: division by zero [-Wdiv-by-zero]
67 | cout << 1/0 << "xdd\n";
| ~^~
mushrooms.cpp:101:34: warning: division by zero [-Wdiv-by-zero]
101 | cout << 1/0 << "xdd\n";
| ~^~
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |