Submission #303599

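| # | Time | Username | Problem | Language | Result | Execution time | Memory |
|---|------|----------|---------|----------|--------|----------------|--------|
| 303599 | | jtnydv25 | Counting Mushrooms (IOI20_mushrooms) | C++17 | 81.59 / 100 | 943 ms | 656 KiB |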
// IOI 2020 "Counting Mushrooms" (IOI20_mushrooms).
//
// Mushroom 0 is the reference (type 0, counted as "A"). use_machine(x) returns
// the number of adjacent pairs in x whose types differ; the decision tree and
// the batch counting below are both built on that model (see answer_for_mask).
//
// Strategy:
//   Phase 1 - classify mushrooms 0 .. min(N, K)-1 exactly, in groups of
//             mushroom 0 plus M - 1 unknowns, walking a precomputed decision
//             tree that at every node asks the query minimising the
//             worst-case number of remaining candidate type assignments.
//   Phase 2 - count the remaining mushrooms in batches interleaved with known
//             mushrooms of the larger known group. The first unknown of each
//             batch sits at the very front, so the parity of the answer gives
//             its exact type; it then joins the known pool, letting later
//             batches grow.
#include "mushrooms.h"

#include <algorithm>
#include <array>
#include <cassert>
#include <memory>
#include <numeric>
#include <utility>
#include <vector>

namespace {

// Mushrooms per decision-tree group: mushroom 0 plus M - 1 unknowns.
constexpr int M = 8;

// Number of leading mushrooms classified exactly in phase 1.
constexpr int K = 400;

// val[i] is the type (0 or 1) of mushroom i once phase 1 has classified it.
// Only indices below min(N, K) are ever written.
int val[200005];

// Number of adjacent type changes use_machine would report for the group
// indices listed in `order`, assuming bit j of `mask` is the type of the
// mushroom at group index j.
int answer_for_mask(int mask, const std::vector<int>& order) {
    int changes = 0;
    for (int j = 0; j + 1 < static_cast<int>(order.size()); j++)
        changes += (mask >> order[j] & 1) != (mask >> order[j + 1] & 1);
    return changes;
}

// One node of the phase-1 decision tree. `states` holds the type masks of the
// current group that are still consistent with every answer seen so far.
struct node {
    std::vector<int> states;
    std::array<std::unique_ptr<node>, M> res;  // child reached for each answer
    std::vector<int> ask;                      // group indices to query, in order
    int maxQueries = 0;                        // worst-case queries from this node

    // Picks this node's query -- the ordered subset of group indices whose
    // largest answer bucket is smallest (the first one found wins ties) --
    // and recursively builds one child per answer that can actually occur.
    void compute() {
        if (states.size() == 1) return;  // leaf: the group is fully determined
        maxQueries = 1;

        int best = 1 << 20;
        for (int mask = 1; mask < (1 << M); mask++) {
            std::vector<int> perm;
            for (int j = 0; j < M; j++)
                if (mask >> j & 1) perm.push_back(j);
            if (perm.size() == 1) continue;  // a one-element query always answers 0
            // perm starts sorted, so next_permutation visits every ordering.
            do {
                std::array<int, M> num{};
                for (int s : states) num[answer_for_mask(s, perm)]++;
                const int worst = *std::max_element(num.begin(), num.end());
                if (worst < best) {
                    best = worst;
                    ask = perm;
                }
            } while (std::next_permutation(perm.begin(), perm.end()));
        }

        std::array<std::vector<int>, M> childStates;
        for (int s : states) childStates[answer_for_mask(s, ask)].push_back(s);
        for (int r = 0; r < M; r++) {
            if (childStates[r].empty()) continue;
            res[r] = std::make_unique<node>();
            res[r]->states = std::move(childStates[r]);
            res[r]->compute();
            maxQueries = std::max(maxQueries, res[r]->maxQueries + 1);
        }
    }

    // Walks the tree with real queries and writes the group's final types into
    // val. positions[j] is the mushroom playing group index j; positions[0]
    // must be mushroom 0, because the root only admits masks with bit 0 clear.
    void get(const std::vector<int>& positions) const {
        if (states.size() == 1) {
            for (int j = 0; j < M; j++) val[positions[j]] = states[0] >> j & 1;
            return;
        }
        std::vector<int> query;
        query.reserve(ask.size());
        for (int j : ask) query.push_back(positions[j]);
        const int r = use_machine(query);
        assert(r >= 0 && r < M && res[r]);  // the answer must match some candidate
        res[r]->get(positions);
    }
};

// Classifies mushroom `pos` on its own by comparing it with mushroom 0.
int get(int pos) { return val[pos] = use_machine({0, pos}); }

// Decision tree over all type masks of one group. Bit 0 (mushroom 0) is always
// type 0, so only even masks are candidates at the root.
struct decision_tree {
    std::unique_ptr<node> root;

    decision_tree() : root(std::make_unique<node>()) {
        for (int mask = 0; mask < (1 << M); mask += 2) root->states.push_back(mask);
        root->compute();
    }

    void get(const std::vector<int>& positions) const { root->get(positions); }
};

}  // namespace

// Returns the number of type-0 ("A") mushrooms among mushrooms 0 .. N-1.
int count_mushrooms(int N) {
    const int n = std::min(N, K);

    // Phase 1: classify mushrooms 0 .. n-1 exactly. Building the tree
    // enumerates every ordered subset of a group, so it is only built once a
    // full group of M - 1 unknowns is actually needed.
    std::unique_ptr<decision_tree> tree;
    for (int st = 1; st < n; st += M - 1) {
        const int en = st + M - 2;
        if (en < n) {
            if (!tree) tree = std::make_unique<decision_tree>();
            std::vector<int> positions = {0};
            for (int j = st; j <= en; j++) positions.push_back(j);
            tree->get(positions);
        } else {
            for (int j = st; j < n; j++) get(j);  // short tail: one query each
        }
    }
    int countA = n - std::accumulate(val, val + n, 0);
    if (n == N) return countA;

    // Phase 2: count mushrooms n .. N-1 in batches.
    std::array<std::vector<int>, 2> where;
    for (int i = 0; i < n; i++) where[val[i]].push_back(i);

    for (int next = n; next < N;) {
        const int id = where[1].size() > where[0].size() ? 1 : 0;
        const std::vector<int>& known = where[id];  // all of type `id`
        const int batch = std::min(static_cast<int>(known.size()), N - next);

        // Query u0 k0 u1 k1 ... u(b-1) k(b-1): u0 is adjacent to one known
        // mushroom and every later u to two, so
        //   answer = [u0 != id] + 2 * #{ j >= 1 : uj != id }.
        std::vector<int> positions;
        positions.reserve(2 * batch);
        for (int j = 0; j < batch; j++) {
            positions.push_back(next + j);
            positions.push_back(known[j]);
        }
        const int answer = use_machine(positions);

        const int firstType = id ^ (answer & 1);  // exact type of u0
        const int othersDiffering = answer / 2;   // among u1 .. u(b-1)
        const int others = batch - 1;
        countA += (firstType == 0) ? 1 : 0;
        countA += (id == 1) ? othersDiffering : others - othersDiffering;

        // This may reallocate where[id]; `known` is not used again this iteration.
        where[firstType].push_back(next);
        next += batch;
    }
    return countA;
}
| # | Verdict | Execution time | Memory | Grader output |
|---|---------|----------------|--------|---------------|
Fetching results...