제출 #311631

# | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리
311631 |  | eriksuenderhauf | 버섯 세기 (IOI20_mushrooms) | C++17 | 100 / 100 | 373 ms | 65800 KiB
#include "mushrooms.h"
#include <bits/stdc++.h>
using namespace std;
// Large sentinel: max_val seeds its minimum searches with +inf, and go()
// returns -inf for states it refuses to enter.
constexpr int inf = 1000000007;
// dp[a][b][q]: memoized go(a, b, q) for a >= b; -1 means "not computed yet"
// (count_mushrooms memsets the table to -1 before use).
int dp[255][255][255];

int go(int known_a, int known_b, int qcnt);

// Evaluates every operation available from state (known_a, known_b) with qcnt
// queries left and returns {best value, operation code}.
// The value of an option is the number of mushrooms it counts plus go() of the
// resulting state. That value is minimized over the possible machine outcomes
// (worst case), and the best option maximizes it.
// Operation codes, as executed by solve():
//   1 - askA on 2 unknowns (both end up identified),
//   2 - askA on known_a unknowns (only the last is identified, the rest are counted),
//   3 - askBig (12 unknowns, 4 queries),
//   4 - askSmall (5 unknowns, 2 queries).
// If no option reaches a positive value, e.g. every successor is rejected by
// go(), operation 2 is returned by default.
// Callers pass known_a >= known_b (both go() and solve() normalize the order).
pair<int,int> max_val(int known_a, int known_b, int qcnt) {
  int v = 0, operation = 2;
  // Alternating queries against one identified class:
  // A.A.(...)
  // B.B.(...)
  // i == 2: query two unknowns; i == known_a: use every identified mushroom of
  // the larger class. When known_a == 2 the same option is simply evaluated twice.
  for (int i : {2, known_a}) {
    if (i > known_a)
      continue;
    int opt = inf;
    // b == 1: the last unknown turns out to be of the other (smaller) class.
    // The `opt > v` guards stop early once this option can no longer beat v.
    for (int b = 0; b < 2 && opt > v; b++) {
      if (i == 2) {
        // a == 1: the first unknown matches; with two unknowns both get identified.
        for (int a = 0; a < 2 && opt > v; a++)
          opt = min(opt, i + go(known_a + (b ^ 1) + a, known_b + b + (a ^ 1), qcnt - 1));
      } else {
        // Only the last unknown becomes identified; the others are merely counted.
        opt = min(opt, i + go(known_a + (b ^ 1), known_b + b, qcnt - 1));
      }
    }
    if (opt > v) {
      v = opt;
      operation = i == 2 ? 1 : 2;
    }
  }

  // Cleared when the 12-unknown pattern is applicable; that suppresses the 5-unknown one.
  int ok = 1;
  // 000001111100 -- 12-unknown, 4-query pattern (askBig)
  if (qcnt > 3) {
    int cnt = 12, qry = 4;
    if (known_a >= 7 && known_b >= 5) {
      int opt = inf;
      // a = how many of the 12 join the larger class; b = the rest join the smaller one.
      for (int a = 0; a <= cnt && opt > v; a++) {
        // NOTE(review): heuristic pruning. When more queries remain and even all 12
        // joining the smaller class keeps it no larger, only a == 0 is evaluated
        // (presumably taken as the worst case). Not proven here; confirm.
        if (qcnt > qry && cnt + known_b <= known_a && a != 0)
          break;
        int b = cnt - a;
        opt = min(opt, cnt + go(known_a + a, known_b + b, qcnt - qry));
      }
      if (opt > v) {
        v = opt;
        operation = 3;
      }
      ok = 0;
    }
  }
  // 00011 -- 5-unknown, 2-query pattern (askSmall)
  if (qcnt > 1 && ok) {
    int cnt = 5, qry = 2;
    if (known_a >= 3 && known_b >= 2) {
      int opt = inf;
      // Same outcome enumeration and pruning as the 12-unknown case.
      for (int a = 0; a <= cnt; a++) {
        if (qcnt > qry && cnt + known_b <= known_a && a != 0)
          break;
        int b = cnt - a;
        opt = min(opt, cnt + go(known_a + a, known_b + b, qcnt - qry));
      }
      if (opt > v) {
        v = opt;
        operation = 4;
      }
    }
  }
  return make_pair(v, operation);
}

// Memoized strategy value: how many mushrooms can be guaranteed counted with
// qcnt more queries, starting from known_a / known_b identified mushrooms of
// the two classes.
// The state is symmetric, so it is stored with the larger count first.
// States whose larger count exceeds 250 return -inf, which keeps every stored
// state inside the dp table.
int go(int known_a, int known_b, int qcnt) {
  if (known_a < known_b)
    swap(known_a, known_b);
  if (!qcnt)
    return 0;
  if (known_a > 250)
    return -inf;
  // Every recursive call made by max_val uses a smaller qcnt, so this entry
  // cannot be written while it is being computed.
  int& memo = dp[known_a][known_b][qcnt];
  if (memo == -1)
    memo = max_val(known_a, known_b, qcnt).first;
  return memo;
}

// Classified mushrooms. ind_a holds identified indices of the same class as
// mushroom 0 and ind_b those of the other class; flip() may swap the roles
// temporarily. cnt_a / cnt_b count every classified mushroom of each class,
// including ones that were only counted in aggregate and never stored.
vector<int> ind_a, ind_b;
int cnt_a = 0, cnt_b = 0;

// When fl is nonzero, swaps the two classes (index lists and counters).
// Calling it again with the same fl restores the original roles.
void flip(int fl) {
  if (!fl)
    return;
  ind_a.swap(ind_b);
  std::swap(cnt_a, cnt_b);
}

// Issues one use_machine query of the form A x[0] A x[1] ... A x[k-1], where
// each A is a mushroom from ind_a. This assumes the machine reports adjacent
// class changes, as build() models it.
// The sequence ends on an unknown, so the answer's parity gives the class of
// x.back(). What remains is twice the number of x[0..k-2] that differ from ind_a.
// Counts are added to cnt_a / cnt_b. Individual indices are stored only when
// fully determined: x.back() always, the others only if they all agree.
// Requires a non-empty x with x.size() <= ind_a.size().
void askA(vector<int>& x) {
  int rest = int(x.size());
  assert(int(ind_a.size()) >= rest);
  vector<int> probe;
  for (int k = 0; k < rest; k++) {
    probe.push_back(ind_a[k]);
    probe.push_back(x[k]);
  }
  int diff = use_machine(probe);
  const bool last_differs = (diff & 1) != 0;
  if (last_differs) {
    ind_b.push_back(x.back());
    cnt_b++;
    diff--;
  } else {
    ind_a.push_back(x.back());
    cnt_a++;
  }
  rest--;  // x.back() is settled; `rest` now covers x[0..rest-1]
  const int differing = diff / 2;
  cnt_b += differing;
  cnt_a += rest - differing;
  if (diff == 0)
    ind_a.insert(ind_a.end(), x.begin(), x.begin() + rest);
  else if (diff == 2 * rest)
    ind_b.insert(ind_b.end(), x.begin(), x.begin() + rest);
}

// Fixed decoding pattern used by askBig(): len_big unknowns are classified
// with qcnt_big use_machine calls.
const int qcnt_big = 4, len_big = 12;
// ord_big[i][j]: which unknown (an index into askBig's x) is placed right after
// the j-th known mushroom in query i.
vector<int> ord_big[qcnt_big] = {
  {8, 0, 9, 10, 7, 11, 5, 1, 6, 4, 2, 3},
  {3, 8, 4, 2, 1, 11, 5, 10, 6, 0, 9, 7},
  {5, 7, 10, 1, 9, 3, 8, 6, 2, 11, 4, 0},
  {2, 3, 11, 4, 6, 1, 9, 8, 7, 0, 5, 10}
};
// msk_big[j]: class of the j-th known mushroom in every query (0 -> taken from
// ind_a, 1 -> from ind_b). It has seven 0s and five 1s, matching the
// (known_a >= 7, known_b >= 5) condition in max_val.
vector<int> msk_big = {0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0};
// Filled by build(): maps the 4 machine answers to a bitmask over x
// (bit i set = x[i] belongs to the ind_b class).
map<array<int,qcnt_big>,int> rev_big;

// Classifies the len_big (12) unknown mushrooms in x with qcnt_big (4) queries.
// Query i interleaves known mushrooms chosen by msk_big (0 -> next from ind_a,
// 1 -> next from ind_b) with the unknowns in the order ord_big[i]. The vector of
// answers is decoded through rev_big, which build() fills.
// Afterwards every x[i] is appended to ind_a or ind_b, and cnt_a / cnt_b are updated.
// msk_big draws 7 mushrooms from one class and 5 from the other, so both
// classes must have at least that many identified members.
void askBig(vector<int>& x) {
  assert(int(x.size()) == len_big);
  // Orient the larger identified set as ind_a; it supplies the seven 0-slots of msk_big.
  int fl = int(ind_a.size()) < int(ind_b.size());
  flip(fl);
  array<int,qcnt_big> arr;
  for (int i = 0; i < qcnt_big; i++) {
    vector<int> x_ord;
    for (int j = 0, ja = 0, jb = 0; j < len_big; j++) {
      x_ord.push_back(msk_big[j] == 0 ? ind_a[ja++] : ind_b[jb++]);
      x_ord.push_back(x[ord_big[i][j]]);
    }
    arr[i] = use_machine(x_ord);
  }
  // Use find() rather than operator[]. An answer vector missing from the table
  // means the decoder is inconsistent; operator[] would silently insert it and
  // decode it as "all 12 match ind_a".
  auto it = rev_big.find(arr);
  assert(it != rev_big.end() && "askBig: answer vector not in rev_big");
  int ans = it == rev_big.end() ? 0 : it->second;
  int ones = __builtin_popcount(ans);
  cnt_b += ones;
  cnt_a += len_big - ones;
  for (int i = 0; i < len_big; i++)
    if ((ans >> i) & 1)
      ind_b.push_back(x[i]);
    else
      ind_a.push_back(x[i]);
  flip(fl);
}

// Fixed decoding pattern used by askSmall(): len_small unknowns are classified
// with qcnt_small use_machine calls.
const int qcnt_small = 2, len_small = 5;
// ord_small[i][j]: which unknown (an index into askSmall's x) is placed right
// after the j-th known mushroom in query i.
vector<int> ord_small[qcnt_small] = {
  {2, 3, 1, 0, 4},
  {1, 2, 0, 4, 3}
};
// msk_small[j]: class of the j-th known mushroom in each query (0 -> from ind_a,
// 1 -> from ind_b). It has three 0s and two 1s, matching the
// (known_a >= 3, known_b >= 2) condition in max_val.
vector<int> msk_small = {1, 1, 0, 0, 0};
// Filled by build(): maps the 2 machine answers to a bitmask over x
// (bit i set = x[i] belongs to the ind_b class).
map<array<int,qcnt_small>,int> rev_small;

// Classifies the len_small (5) unknown mushrooms in x with qcnt_small (2) queries.
// Query i interleaves known mushrooms chosen by msk_small (0 -> next from ind_a,
// 1 -> next from ind_b) with the unknowns in the order ord_small[i]. The answers
// are decoded through rev_small, which build() fills.
// Afterwards every x[i] is appended to ind_a or ind_b, and cnt_a / cnt_b are updated.
// msk_small draws 3 mushrooms from one class and 2 from the other, so both
// classes must have at least that many identified members.
void askSmall(vector<int>& x) {
  assert(int(x.size()) == len_small);
  // Orient the larger identified set as ind_a; it supplies the three 0-slots of msk_small.
  int fl = int(ind_a.size()) < int(ind_b.size());
  flip(fl);
  array<int,qcnt_small> arr;
  for (int i = 0; i < qcnt_small; i++) {
    vector<int> x_ord;
    for (int j = 0, ja = 0, jb = 0; j < len_small; j++) {
      x_ord.push_back(msk_small[j] == 0 ? ind_a[ja++] : ind_b[jb++]);
      x_ord.push_back(x[ord_small[i][j]]);
    }
    arr[i] = use_machine(x_ord);
  }
  // Use find() rather than operator[]. An answer pair missing from the table
  // means the decoder is inconsistent; operator[] would silently insert it and
  // decode it as "all 5 match ind_a".
  auto it = rev_small.find(arr);
  assert(it != rev_small.end() && "askSmall: answer vector not in rev_small");
  int ans = it == rev_small.end() ? 0 : it->second;
  int ones = __builtin_popcount(ans);
  cnt_b += ones;
  cnt_a += len_small - ones;
  for (int i = 0; i < len_small; i++)
    if ((ans >> i) & 1)
      ind_b.push_back(x[i]);
    else
      ind_a.push_back(x[i]);
  flip(fl);
}

// Precomputes the decoding tables rev_small and rev_big. For every assignment
// (bitmask) of classes to the unknowns, it simulates each fixed-pattern query
// and records answers -> bitmask. If two masks give the same answers, the later
// one is stored.
void build() {
  // Number of adjacent class changes in known[0], probe[0], known[1], probe[1], ...
  // This is the quantity use_machine is decoded as reporting.
  auto transitions = [](const vector<int>& known, const vector<int>& probe) {
    assert(known.size() == probe.size());
    const int m = int(known.size());
    int changes = 0;
    for (int i = 0; i < m; i++) {
      changes += int(known[i] != probe[i]);
      if (i + 1 < m)
        changes += int(probe[i] != known[i + 1]);
    }
    return changes;
  };
  // Fills `table` for a pattern with `len` unknowns, per-query orders `order`,
  // and known-class sequence `pattern`. The query count comes from the key's array size.
  auto populate = [&](int len, const vector<int>* order, const vector<int>& pattern, auto& table) {
    using Key = typename std::decay_t<decltype(table)>::key_type;
    const int rounds = int(std::tuple_size<Key>::value);
    for (int msk = 0; msk < (1 << len); msk++) {
      Key answers{};
      for (int r = 0; r < rounds; r++) {
        vector<int> probe;
        for (int j = 0; j < len; j++)
          probe.push_back((msk >> order[r][j]) & 1);
        answers[r] = transitions(pattern, probe);
      }
      table[answers] = msk;
    }
  };
  populate(len_small, ord_small, msk_small, rev_small);
  populate(len_big, ord_big, msk_big, rev_big);
}

// Drives the adaptive strategy. For the current state it asks max_val() which
// operation is best, applies it to the next unclassified indices, and recurses
// until all n mushrooms are counted.
//   known_a, known_b - sizes of ind_a / ind_b (identified mushrooms per class)
//   qcnt             - remaining query budget the DP plans with
//   n                - total number of mushrooms
// Mushrooms are classified in index order, so cnt_a + cnt_b is the first
// unclassified index.
void solve(int known_a, int known_b, int qcnt, int n) {
  assert(qcnt >= 0);
  const int next = cnt_a + cnt_b;
  if (next >= n)
    return;
  // Orient so that ind_a is the larger identified set; max_val expects known_a >= known_b.
  int fl = int(known_a < known_b);
  if (fl)
    swap(known_a, known_b);
  flip(fl);
  // Up to `want` consecutive unclassified indices starting at `next`.
  auto take = [&](int want) {
    vector<int> x;
    for (int i = 0; i < want && next + i < n; i++)
      x.push_back(next + i);
    return x;
  };
  // Fallback when the tail is shorter than a fixed pattern: one askA query.
  // askA can pair at most ind_a.size() unknowns with known mushrooms, so the
  // query is capped at that; any remainder is handled in the next round.
  auto ask_tail = [&](vector<int>& x) {
    if (int(x.size()) > int(ind_a.size()))
      x.resize(ind_a.size());
    askA(x);
    qcnt--;
  };
  const int operation = max_val(known_a, known_b, qcnt).second;
  switch (operation) {
    case 1: {
      vector<int> x = take(2);
      askA(x);
      qcnt--;
      break;
    }
    case 2: {
      vector<int> x = take(known_a);
      askA(x);
      qcnt--;
      break;
    }
    case 3: {
      vector<int> x = take(len_big);
      if (int(x.size()) != len_big) {
        ask_tail(x);
      } else {
        askBig(x);
        qcnt -= qcnt_big;
      }
      break;
    }
    case 4: {
      vector<int> x = take(len_small);
      if (int(x.size()) != len_small) {
        ask_tail(x);
      } else {
        askSmall(x);
        qcnt -= qcnt_small;
      }
      break;
    }
  }
  flip(fl);
  solve(int(ind_a.size()), int(ind_b.size()), qcnt, n);
}

int count_mushrooms(int n) {
  memset(dp, -1, sizeof dp);
  go(1, 0, 228);
  build();
  ind_a.push_back(0);
  cnt_a = 1;
  solve(1, 0, 228, n);
  return cnt_a;
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...