Submission #1212312

# | Time | Username | Problem | Language | Result | Execution time | Memory
1212312 | | Zbyszek99 | Counting Mushrooms (IOI20_mushrooms) | C++20 | 100 / 100 | 7 ms | 1604 KiB
#include "mushrooms.h"
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
// ---------------------------------------------------------------------------
// Shorthand macros, aliases and constants used by the rest of the file.
// ---------------------------------------------------------------------------

// Type shorthands (textual macros, not real aliases).
#define ll long long
#define ld long double
#define ull unsigned long long
// Pair member shorthands: x.ff == x.first, x.ss == x.second.
#define ff first
#define ss second
#define pii pair<int,int>
#define pll pair<long long, long long>
#define vi vector<int>
#define vl vector<long long>
#define pb push_back
// rep:  i runs over 0 .. b-1.
// rep2: i runs over a .. b, both ends inclusive.
// rep3: like rep2 but advances i by c each step.
#define rep(i, b) for(int i = 0; i < (b); ++i)
#define rep2(i,a,b) for(int i = a; i <= (b); ++i)
#define rep3(i,a,b,c) for(int i = a; i <= (b); i+=c)
#define count_bits(x) __builtin_popcountll((x))
#define all(x) (x).begin(),(x).end()
// Container size converted to int.
#define siz(x) (int)(x).size()
// Iterates over x, binding each element by reference to `it`.
#define forall(it,x) for(auto& it:(x))
using namespace __gnu_pbds;
using namespace std;
// pb_ds red-black tree with order-statistics node updates; not referenced by
// the code visible in this file.
typedef tree<int, null_type, less<int>, rb_tree_tag,tree_order_statistics_node_update> ordered_set;
// Disabled random-number helpers, kept commented out.
//mt19937 mt;void random_start(){mt.seed(chrono::time_point_cast<chrono::milliseconds>(chrono::high_resolution_clock::now()).time_since_epoch().count());}
//ll los(ll a, ll b) {return a + (mt() % (b-a+1));}
// Generic constants; none of the three is used by the code visible in this file.
const int INF = 1e9+50;
const ll INF_L = 1e18+40;
const ll MOD = 1e9+7;

// File-scope solver state shared by all functions below.

// queries[k]: the query masks of level k, filled by get_queries(). Bit i of a
// mask selects p[i] of the batch handed to get_vals(). init() builds levels 0..7.
vector<bitset<7000>> queries[8];
// bit_ans[k]: number of batch positions one level-k batch covers (1 for k == 0,
// 2*bit_ans[k-1] + 2^(k-1) - 1 otherwise).
int bit_ans[8];
// min_known[k]: threshold computed in get_queries(); init() only allows level k
// at dp index i when i - bit_ans[k] >= min_known[k].
int min_known[8];
// dp[i] = {accumulated cost, last level used}; a level-k step adds (1 << k) to
// the cost. Filled by init() for i up to min(n, 1000).
pii dp[1001];
// Indices classified as "ones". init() seeds it with index 0.
// NOTE(review): presumably "same species as mushroom 0" -- confirm against the task statement.
vi ones;
// Indices classified as the other kind.
vi zeros;
// Indices not yet handed out by get_un().
set<int> not_know;
// Levels to execute, in ascending order (chosen and sorted by init()).
vi moves;
// Running count of "ones"; this is the value count_mushrooms() returns.
// Starts at 1, matching the single index 0 that init() puts into `ones`.
int sum1 = 1;

// Removes the smallest index still stored in not_know and returns it.
// Returns -1 (and changes nothing) when not_know is empty.
int get_un()
{
	if(not_know.empty()) return -1;
	auto smallest = not_know.begin();
	const int idx = *smallest;
	not_know.erase(smallest);
	return idx;
}

// Issues one use_machine() call and returns how many entries of v are "ones".
//
// The query interleaves the entries of v with already classified indices
// ("anchors") taken from whichever of ones/zeros is currently larger (ones on
// a tie). Where possible one additional unknown index from not_know is placed
// at the very front; the parity of the machine answer then classifies it, and
// it is appended to ones (bumping sum1) or to zeros as a side effect.
//
// NOTE(review): the decoding assumes use_machine() returns the number of
// adjacent positions in the query that hold different kinds; that function is
// declared in mushrooms.h and is not visible here.
//
// The caller must make sure enough anchors exist: siz(v)+1 of them when a
// leading anchor is used, siz(v) otherwise. No bounds check is done here.
int get_sum(vi v)
{
	const bool anchor_is_one = siz(ones) >= siz(zeros);
	vi& anchors = anchor_is_one ? ones : zeros;

	// With only a single "one" available there is no spare anchor to put in
	// front, so neither a leading anchor nor an extra unknown is used.
	const bool with_lead = !anchor_is_one || siz(ones) != 1;
	const int un = with_lead ? get_un() : -1;

	vi query;
	int next_anchor = 0;
	if(un != -1) query.pb(un);
	if(with_lead) query.pb(anchors[next_anchor++]);
	for(int idx : v)
	{
		query.pb(idx);
		query.pb(anchors[next_anchor++]);
	}

	int s = use_machine(query);

	if(un != -1)
	{
		// The extra unknown touches exactly one neighbour (an anchor), so it
		// alone decides whether the answer is odd.
		const bool differs_from_anchor = (s % 2 == 1);
		if(differs_from_anchor) s--;
		const bool un_is_one = (anchor_is_one != differs_from_anchor);
		if(un_is_one)
		{
			ones.pb(un);
			sum1++;
		}
		else
		{
			zeros.pb(un);
		}
	}

	// s/2 counts the entries of v that differ from the anchors.
	if(anchor_is_one) return (2*siz(v) - s)/2;
	return s/2;
}

void get_queries(int k)
{
	if(k == 0)
	{
		min_known[k] = 1;
		bit_ans[k] = 1;
		queries[k].resize(1);
		queries[k][0] = 1;
		return;
	}
	get_queries(k-1);
	bit_ans[k] = 2*bit_ans[k-1] + (1 << (k-1))-1;
	queries[k].resize((1 << k));
	int cur_new = bit_ans[k-1]*2;
	rep(i,(1 << (k-1))-1)
	{
		queries[k][i*2] = queries[k-1][i] | (queries[k-1][i] << bit_ans[k-1]);	
		queries[k][i*2+1] = queries[k-1][i] | ((queries[k-1].back() << bit_ans[k-1]) ^ (queries[k-1][i] << bit_ans[k-1]));
		queries[k][i*2+1][cur_new] = 1;
		cur_new++;
	}
	rep2(i,bit_ans[k-1],bit_ans[k-1]*2-1) queries[k][(1 << k)-2][i] = 1;
	rep(i,bit_ans[k]) queries[k][(1 << k)-1][i] = 1;
	vi sizes;
	forall(it,queries[k]) sizes.pb(it.count());
	sort(all(sizes));
	min_known[k] = 0;
	rep(i,siz(sizes))
	{
		min_known[k] = max(min_known[k],sizes[i]*2 + 1 - i);
	}
}

// Counts the steps needed to bring `rest` down to zero or below, where the
// step taken at counter value c removes (c+1)/2 and the counter starts at 2*k
// and grows by one per step. Returns 0 when rest is already <= 0.
// init() uses the result as the number of further queries after the planned
// batches.
int know_rest(int k, int rest)
{
	int steps = 0;
	for(int counter = 2*k; rest > 0; ++counter)
	{
		rest -= (counter+1)/2;
		++steps;
	}
	return steps;
}

// Decodes the answers of one level-k batch into a classification of every
// position in the batch.
//   p     - the batch positions; p[i] corresponds to bit i of the level-k masks
//   anses - anses[j] is the number of "ones" among the positions selected by
//           queries[k][j] (as returned by get_sum)
//   k     - the level
// Side effects: every position of p ends up appended to ones or zeros, and sum1
// is incremented once per position classified as a "one".
//
// Batch layout (see get_queries): left half, right half (bit_ans[k-1] positions
// each), then one extra position per query pair.
void get_vals_rek(vi p, vi anses, int k)
{
	if(k == 0)
	{
		// A level-0 batch is a single position whose answer is its own value.
		if(anses[0] == 1)
		{
			ones.pb(p[0]);
			sum1++;
		}
		else
		{
			zeros.pb(p[0]);
		}
		return;
	}
	vi p_left;
	vi p_right;
	// The last two masks cover the whole right half and the whole batch.
	int right_sum = anses[(1 << k)-2];
	int total_sum = anses[(1 << k)-1];
	// Starts as "left half + extras"; the extras found to be ones are
	// subtracted in the loop below, leaving the left half alone.
	int left_sum = total_sum - right_sum;
	vi left_anses;
	vi right_anses;
	rep(i,bit_ans[k-1]) p_left.pb(p[i]);
	rep2(i,bit_ans[k-1],bit_ans[k-1]*2-1) p_right.pb(p[i]);
	// Index in p of the next extra position.
	int cur_new = bit_ans[k-1]*2;
	rep(i,(1 << (k-1))-1)
	{
		// With L / R = ones selected by sub-mask i in the left / right half and
		// x = value of this pair's extra position:
		//   a = L + R
		//   b = L + (right_sum - R) + x
		// hence a - b + right_sum = 2R - x, whose parity is x.
		int a = anses[i*2];
		int b = anses[i*2+1];
		// The even offset keeps the left operand of % non-negative.
		int new_ind = (a-b+right_sum+500000000) % 2;
		if(new_ind == 1)
		{
			sum1++;
			ones.pb(p[cur_new++]);
			left_sum--;
		}
		else
		{
			zeros.pb(p[cur_new++]);
		}
		// R = (2R - x + x) / 2 and L = (2L + x - x) / 2.
		int right_ans = (a-b+right_sum+new_ind)/2;
		int left_ans = (a+b-right_sum-new_ind)/2;
		left_anses.pb(left_ans);
		right_anses.pb(right_ans);
	}
	// The last answer of each half is the total over that half.
	left_anses.pb(left_sum);
	right_anses.pb(right_sum);
	get_vals_rek(p_left,left_anses,k-1);
	get_vals_rek(p_right,right_anses,k-1);
}

// Runs every level-k query on the batch p and classifies all positions of p.
//   p - batch positions; p[i] corresponds to bit i of the level-k masks
//   k - level whose masks (queries[k]) are used
// The masks are asked in ascending order of how many positions they select
// (ties broken by mask index); each answer is obtained through get_sum(), and
// the full answer vector is then decoded by get_vals_rek().
void get_vals(vi p, int k)
{
	const int query_cnt = siz(queries[k]);
	vi anses(query_cnt);
	vector<pii> query_sort;
	query_sort.reserve(query_cnt);
	rep(i,query_cnt)
	{
		query_sort.pb({(int)queries[k][i].count(),i});
	}
	sort(all(query_sort));
	forall(it,query_sort)
	{
		const auto& mask = queries[k][it.ss];
		// Only bits that have a matching entry in p can be meaningful, so the
		// scan stops at the batch length instead of walking the full bitset
		// width; this also keeps p[i] in bounds.
		const int limit = std::min(siz(p),siz(mask));
		vi q;
		rep(i,limit)
		{
			if(mask[i]) q.pb(p[i]);
		}
		anses[it.ss] = get_sum(q);
	}
	get_vals_rek(p,anses,k);
}

void init(int n)
{
	ones.pb(0);
	get_queries(7);
	dp[0] = {0,0};
	dp[1] = {0,0};
	pii total_min = {1e9,1};
	rep2(i,2,min(n,1000))
	{
		dp[i] = {1e9,-1};
		rep2(k,0,7)
		{
			if(i - bit_ans[k] >= min_known[k])
			{
				dp[i] = min(dp[i],{dp[i - bit_ans[k]].ff+(1 << k),k});
			}
		}
		total_min = min(total_min,{dp[i].ff + know_rest((dp[i].ff + i)/2,n-(dp[i].ff + i)),i});
	}
	int cur_elms = total_min.ss;
	while(cur_elms != 0)
	{
		moves.pb(dp[cur_elms].ss);
		cur_elms -= bit_ans[dp[cur_elms].ss];
	}
	sort(all(moves));
}

int count_mushrooms(int n) 
{
	init(n-1);
	rep2(i,1,n-1) not_know.insert(i);
	forall(it,moves)
	{
		vi new_p;
		rep(i,bit_ans[it])
		{
			new_p.pb(get_un());
		}
		get_vals(new_p,it);
	}
	while(siz(not_know) > 0)
	{
		int z = min({siz(not_know),max(siz(ones)-1,siz(zeros)-1)});
		vi p;
		rep(i,z)
		{
			p.pb(get_un());
		}
		sum1 += get_sum(p);
	}
	return sum1;
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...