#include <bits/stdc++.h>
#include "mushrooms.h"
// Competitive-programming shorthand: type alias and container helpers.
// NOTE(review): `all(x)` and `sz(s)` do not parenthesize their argument, so
// they are only safe with simple identifiers.
#define ll long long
#define pb push_back
#define all(x) x.begin(),x.end()
#define sz(s) (int)s.size()
#define lb lower_bound
using namespace std;
// NOTE(review): MAX and inf appear unused in the visible code — confirm
// before removing.
const int MAX=2e4+100;
const int inf=1e9;
// Returns how many of the mushrooms 0..n-1 belong to the same species as
// mushroom 0 (called "type A" below; the other species is "type B").
//
// Relies on use_machine(x) (declared in mushrooms.h), which returns the number
// of adjacent pairs in x whose species differ. Consequently:
//   * an unknown mushroom sandwiched between two mushrooms of one known species
//     adds 2 to the result iff it differs from them;
//   * an unknown at either end of x adds 1 iff it differs from its neighbour.
//
// Phase 1 classifies mushrooms exactly (1, 2 or up to 5 per couple of queries,
// depending on how many reference mushrooms are known) until one of the
// known-species lists reaches kRefTarget entries or every mushroom is known.
// Phase 2 then interleaves batches of unknowns with the larger list: each query
// yields how many sandwiched unknowns differ from that list's species, plus
// the exact species of the batch's trailing unknown, which is added to the
// known lists so later batches can grow.
//
// @param n  number of mushrooms; mushroom 0 is type A by definition.
// @return   number of type-A mushrooms, including mushroom 0.
int count_mushrooms(int n){
    std::vector<int> a = {0};      // indices known to be type A
    std::vector<int> b;            // indices known to be type B
    int pos = 1;                   // first index not yet classified in phase 1
    const int kRefTarget = 105;    // phase-1 stopping size (tuning constant)

    // ---------------- Phase 1: exact classification ----------------
    while (pos < n &&
           static_cast<int>(std::max(a.size(), b.size())) < kRefTarget) {
        const int na = static_cast<int>(a.size());
        const int nb = static_cast<int>(b.size());

        if (pos + 4 < n && std::min(na, nb) > 1 && std::max(na, nb) > 2) {
            // `big` has >= 3 members, `small` has >= 2 (ties choose b as big).
            std::vector<int>& big = na > nb ? a : b;
            std::vector<int>& small = na > nb ? b : a;

            // pos and pos+1 are sandwiched (2 each if they differ from big),
            // pos+2 trails (1 if it differs).
            const int cnt =
                use_machine({big[0], pos, big[1], pos + 1, big[2], pos + 2});
            (cnt & 1 ? small : big).push_back(pos + 2);
            const int differing = cnt / 2;  // how many of pos, pos+1 differ

            if (differing == 0) {
                big.push_back(pos);
                big.push_back(pos + 1);
                pos += 3;
            } else if (differing == 1) {
                // Exactly one of pos, pos+1 differs from big. Below, pos sits
                // between two `small` members and pos+1 between two `big`
                // members, so together they contribute 4 if pos has big's
                // species (pos+1 then has small's) and 0 otherwise. pos+3 is
                // sandwiched (2 if it differs from big) and pos+4 trails (1).
                // The fixed small[1]|big[0] boundary always adds 1, hence -1.
                const int code = use_machine({small[0], pos, small[1], big[0],
                                              pos + 1, big[1], pos + 3, big[2],
                                              pos + 4}) - 1;
                (code & 1 ? small : big).push_back(pos + 4);
                (code & 2 ? small : big).push_back(pos + 3);
                if (code & 4) {
                    big.push_back(pos);
                    small.push_back(pos + 1);
                } else {
                    big.push_back(pos + 1);
                    small.push_back(pos);
                }
                pos += 5;
            } else {
                assert(differing == 2);
                small.push_back(pos);
                small.push_back(pos + 1);
                pos += 3;
            }
        } else if (pos + 1 < n && std::max(na, nb) > 1) {
            // Use a list with two members (a preferred, as before).
            std::vector<int>& ref = na >= 2 ? a : b;
            std::vector<int>& other = na >= 2 ? b : a;
            // pos is at the front (adds 1 if it differs); pos+1 is sandwiched
            // (adds 2 if it differs).
            const int cnt = use_machine({pos, ref[0], pos + 1, ref[1]});
            ((cnt & 1) ? other : ref).push_back(pos);
            ((cnt & 2) ? other : ref).push_back(pos + 1);
            pos += 2;
        } else {
            // Only mushroom 0 is usable as a reference: compare one at a time.
            (use_machine({pos, 0}) ? b : a).push_back(pos);
            ++pos;
        }
    }

    // ---------------- Phase 2: batch counting ----------------
    int ans = static_cast<int>(a.size());
    for (int i = pos; i < n;) {
        // Probe with the larger known list; all its members share one species.
        const bool useA = a.size() > b.size();
        std::vector<int>& known = useA ? a : b;
        std::vector<int>& other = useA ? b : a;
        const int k = static_cast<int>(known.size());

        // Unknowns i .. sandwichedEnd-1 sit between known[j-i] and the next
        // known member; `tail` (if it exists) follows known.back().
        const int sandwichedEnd = std::min(n, i + k - 1);
        const int tail = i + k - 1;
        const bool hasTail = tail < n;

        std::vector<int> query;
        query.reserve(2 * static_cast<std::size_t>(k));
        for (int j = i; j < sandwichedEnd; ++j) {
            query.push_back(known[j - i]);
            query.push_back(j);
        }
        query.push_back(known.back());
        if (hasTail) query.push_back(tail);

        const int cnt = use_machine(query);
        const int sandwiched = sandwichedEnd - i;  // count, not an index
        const int differing = cnt / 2;             // sandwiched unknowns unlike `known`
        ans += useA ? sandwiched - differing : differing;

        if (hasTail) {
            const bool tailDiffers = (cnt % 2) == 1;
            (tailDiffers ? other : known).push_back(tail);
            // The tail is type A iff (known is A) XOR (tail differs from known).
            if (useA != tailDiffers) ++ans;
        }
        i = std::min(n, i + k);
    }
    return ans;
}
/*
 * Judge results recorded for this submission (pasted from the grader):
 *
 *  #  | Result (결과) | Time (실행 시간) | Memory (메모리) | Grader output
 *  ---+--------------+------------------+-----------------+------------------------
 *  1  | Correct      | 0 ms             | 344 KB          | Output is correct
 *  2  | Correct      | 0 ms             | 344 KB          | Output is correct
 *  3  | Correct      | 0 ms             | 344 KB          | Output is correct
 *  4  | Correct      | 0 ms             | 344 KB          | Output is correct
 *  5  | Correct      | 1 ms             | 344 KB          | Output is correct
 *  6  | Incorrect    | 1 ms             | 344 KB          | Answer is not correct.
 *  7  | Halted       | 0 ms             | 0 KB            | -
 */