Submission #463677

# | Time | Username | Problem | Language | Result | Execution time | Memory
463677 | | rainboy | Counting Mushrooms (IOI20_mushrooms) | C++17 | 0 / 100 | 1 ms | 456 KiB
/* https://codeforces.com/blog/entry/82924 */
/* https://loj.ac/s/938329 (whzzt) */
#include "mushrooms.h"
#include <stdlib.h>
#include <string.h>

using namespace std;

/* Shorthand for a growable list of ints. */
using vi = vector<int>;

/* Returns the smaller of the two ints. */
int min(int a, int b) {
	if (a < b)
		return a;
	return b;
}
/* Returns the larger of the two ints. */
int max(int a, int b) {
	if (a > b)
		return a;
	return b;
}

/*
 * N  - capacity of the index lists and scratch arrays.
 * Q  - size of the query-count planning tables dp[] / lg_[] (indices 0..Q).
 * LG - highest table level; levels run from 0 to LG inclusive.
 */
const int N = 20000, Q = 203, LG = 7;	/* LG = floor(log2(Q)) */

/* State of the small private generator used to choose sort() pivots. */
unsigned int X = 12345;

/*
 * Advances X by multiplying it by 3 (wrapping as unsigned) and returns the
 * new state shifted right by one bit, so the result is never negative.
 */
int rand_() {
	return (X *= 3) >> 1;
}

/*
 * Per-level tables filled by init():
 *   mask[lg][h] - 0/1 row h of level lg (nn[lg] entries per row)
 *   cnt[lg][h]  - number of ones recorded for that row
 *   hhh[lg]     - row indices of level lg ordered by cnt
 *   nn[lg]      - row length of level lg
 *   lower[lg]   - threshold dp[q] must reach before level lg may be used
 * dp[] / lg_[] are indexed by a query count in 0..Q.
 *
 * The per-level tables are indexed with lg in 0..LG *inclusive*, so they need
 * LG + 1 slots; declaring them with only LG slots made every access at
 * level LG run past the end of the array.
 */
char **mask[LG + 1]; int *cnt[LG + 1], *hhh[LG + 1], nn[LG + 1], lower[LG + 1], dp[Q + 1], lg_[Q + 1];

/* Level whose cnt[] row sort() uses as the comparison key. */
int l_;

/*
 * Sorts hh[l..r) in ascending order of cnt[l_][hh[.]].
 *
 * Three-way quicksort: each round draws a pivot with rand_(), splits the
 * range into "smaller / equal / larger" keys, recurses into the smaller part
 * and keeps looping on the larger part.
 */
void sort(int *hh, int l, int r) {
	for (; l < r; ) {
		const int pivot = hh[l + rand_() % (r - l)];
		const int pivot_key = cnt[l_][pivot];
		int lt = l;	/* hh[l..lt) holds keys below the pivot key */
		int cur = l;	/* hh[lt..cur) holds keys equal to the pivot key */
		int gt = r;	/* hh[gt..r) holds keys above the pivot key */

		while (cur < gt) {
			const int key = cnt[l_][hh[cur]];

			if (key < pivot_key) {
				const int tmp = hh[lt];

				hh[lt] = hh[cur];
				hh[cur] = tmp;
				lt++;
				cur++;
			} else if (key > pivot_key) {
				gt--;
				const int tmp = hh[cur];

				hh[cur] = hh[gt];
				hh[gt] = tmp;
			} else {
				cur++;
			}
		}
		sort(hh, l, lt);
		l = gt;
	}
}

/*
 * Builds every static table used by trace() and count_mushrooms().
 *
 * For each level lg in 0..LG it allocates 2^lg mask rows of nn[lg] entries
 * plus the matching cnt[lg] / hhh[lg] arrays, fills the rows level by level,
 * orders the rows of each level by cnt, derives lower[lg], and finally fills
 * the planning tables dp[] / lg_[].
 *
 * The allocations made here are never released by this function.
 * NOTE(review): the per-level tables are indexed with lg == LG, so they must
 * provide at least LG + 1 entries.
 */
void init() {
	int n, n_, q, q_, h, h0, h1, i, lg;

	/* Allocate level lg: q = 2^lg rows, each nn[lg] entries long (zeroed by calloc). */
	for (lg = 0; lg <= LG; lg++) {
		/* nn[0] = 1; nn[lg] = 2 * nn[lg - 1] + 2^(lg - 1) - 1 otherwise. */
		q = 1 << lg, nn[lg] = lg == 0 ? 1 : nn[lg - 1] * 2 + q / 2 - 1;
		mask[lg] = (char **) malloc(q * sizeof *mask[lg]);
		for (h = 0; h < q; h++)
			mask[lg][h] = (char *) calloc(nn[lg], sizeof *mask[lg][h]);
		cnt[lg] = (int *) malloc(q * sizeof *cnt[lg]);
		hhh[lg] = (int *) malloc(q * sizeof *hhh[lg]);
	}
	/* Level 0: one row that selects its single position. */
	cnt[0][0] = 1, mask[0][0][0] = 1, hhh[0][0] = 0;
	lower[0] = 3;
	/* Derive level lg + 1 (row length n_, q_ rows) from level lg (row length n, q rows). */
	for (lg = 0; lg < LG; lg++) {
		n = nn[lg], n_ = nn[lg + 1], q = 1 << lg, q_ = 1 << lg + 1;
		/*
		 * Every row h of level lg except the last one produces two rows:
		 *   h0 - row h repeated over positions [0, n) and [n, 2n)
		 *   h1 - row h on [0, n), its complement on [n, 2n), plus the extra
		 *        position 2n + h
		 */
		for (h = 0; h + 1 < q; h++) {
			h0 = h << 1 | 0, h1 = h << 1 | 1;
			cnt[lg + 1][h0] = cnt[lg][h] + cnt[lg][h];
			cnt[lg + 1][h1] = cnt[lg][h] + (n - cnt[lg][h]) + 1;
			for (i = 0; i < n; i++) {
				mask[lg + 1][h0][i] = mask[lg][h][i], mask[lg + 1][h0][n + i] = mask[lg][h][i];
				mask[lg + 1][h1][i] = mask[lg][h][i], mask[lg + 1][h1][n + i] = mask[lg][h][i] ^ 1;
			}
			mask[lg + 1][h1][n * 2 + h] = 1;
		}
		/*
		 * Here h == q - 1, so h0 / h1 are the last two rows of level lg + 1:
		 *   h0 - ones on [0, n), zeros on [n, 2n), ones on every extra position
		 *   h1 - ones everywhere
		 */
		h0 = h << 1 | 0, h1 = h << 1 | 1;
		cnt[lg + 1][h0] = n_ - n, cnt[lg + 1][h1] = n_;
		for (i = 0; i < n; i++) {
			mask[lg + 1][h0][i] = 1, mask[lg + 1][h0][n + i] = 0;
			mask[lg + 1][h1][i] = 1, mask[lg + 1][h1][n + i] = 1;
		}
		for (h = 0; h + 1 < q; h++) {
			mask[lg + 1][h0][n * 2 + h] = 1;
			mask[lg + 1][h1][n * 2 + h] = 1;
		}
		for (h = 0; h < q_; h++)
			hhh[lg + 1][h] = h;
		/*
		 * For ordering and for lower[], the last row is counted as its
		 * difference from the second-to-last row; the full count is put back
		 * below because the next level is built from the full counts.
		 */
		cnt[lg + 1][q_ - 1] -= cnt[lg + 1][q_ - 2];
		l_ = lg + 1, sort(hhh[lg + 1], 0, q_);
		/*
		 * lower[lg + 1] = max over sorted rank h of (2 * cnt + 1 - h); it starts
		 * from the zero value of the global array.
		 * NOTE(review): presumably the number of already classified mushrooms
		 * needed so that sum() always has enough known indices - confirm.
		 */
		for (h = 0; h < q_; h++)
			lower[lg + 1] = max(lower[lg + 1], cnt[lg + 1][hhh[lg + 1][h]] * 2 + 1 - h);
		cnt[lg + 1][q_ - 1] += cnt[lg + 1][q_ - 2];
	}
	/*
	 * Once all levels exist, permanently store the last row of every level
	 * lg >= 1 as its difference from the second-to-last row, in both cnt and
	 * mask (XOR); trace() adds the second-to-last result back afterwards.
	 */
	for (lg = 1; lg <= LG; lg++) {
		cnt[lg][(1 << lg) - 1] -= cnt[lg][(1 << lg) - 2];
		for (i = 0; i < nn[lg]; i++)
			mask[lg][(1 << lg) - 1][i] ^= mask[lg][(1 << lg) - 2][i];
	}
	/* Base cases: dp[0..2] = 1, 2, 3 with no level attached (lg_ = -1). */
	for (q = 0; q < 3; q++)
		dp[q] = q + 1, lg_[q] = -1;
	/*
	 * From q, a level lg with dp[q] >= lower[lg] leads to q + 2^lg with value
	 * dp[q] + nn[lg] + 2^lg; keep the largest value and remember its level.
	 * NOTE(review): dp[q] is presumably the number of mushrooms classified
	 * after q queries (trace() uses it as the start index of the next batch).
	 */
	for (q = 2; q <= Q; q++)
		for (lg = 0; lg <= LG && (q + (1 << lg)) <= Q; lg++)
			if (dp[q] >= lower[lg]) {
				n = dp[q] + nn[lg] + (1 << lg);
				if (dp[q + (1 << lg)] < n)
					dp[q + (1 << lg)] = n, lg_[q + (1 << lg)] = lg;
			}
}

/* Scratch index list, rebuilt before every use_machine() call. */
vi ii_;
/*
 * ii[t][0..kk[t]) holds the mushroom indices recorded under list t (0 or 1);
 * kk[t] is the current length of list t.
 * NOTE(review): presumably list 0 is species A and list 1 species B - confirm
 * against the task statement.
 */
int ii[2][N], kk[2];

int sum(char *msk, int l, int r, int i_) {
	int t, i, k, x;

	t = kk[0] > kk[1] ? 0 : 1;
	ii_.clear(), k = 0;
	if (i_ != -1)
		ii_.push_back(i_), ii_.push_back(ii[t][k++]);
	for (i = l; i < r; i++)
		if (i_ == -1 || msk[i - l])
			ii_.push_back(i), ii_.push_back(ii[t][k++]);
	x = use_machine(ii_);
	if (t == 1)
		x = k * 2 - 1 - x;
	ii[x % 2][kk[x % 2]++] = i_ == -1 ? l : i_;
	return x / 2;
}

/*
 * ss - per-row results collected by trace() and rewritten in place by solve()
 * tt - scratch buffer solve() fills before copying it back over its input
 */
int ss[N], tt[N];

void solve(int *ss, int lg, int l) {
	int i, a, c;

	if (lg == 0) {
		ii[ss[0]][kk[ss[0]]++] = l;
		return;
	}
	a = ss[(1 << lg) - 2], c = ss[(1 << lg) - 1] - a;
	for (i = 0; i + 1 < 1 << lg - 1; i++) {
		int i0 = i << 1 | 0, i1 = i << 1 | 1, t;

		t = tt[(1 << lg) + i] = (ss[i0] + ss[i1] + c) % 2;
		if (t == 1)
			a--;
		tt[i] = (ss[i0] + ss[i1] - t - c) / 2, tt[(1 << lg - 1) + i] = (ss[i0] - ss[i1] + t + c) / 2;
		ii[t][kk[t]++] = l + nn[lg - 1] * 2 + i;
	}
	tt[(1 << lg - 1) - 1] = a, tt[(1 << lg) - 1] = c;
	memcpy(ss, tt, (1 << lg) * sizeof *ss);
	solve(ss, lg - 1, l), solve(ss + (1 << lg - 1), lg - 1, l + nn[lg - 1]);
}

void trace(int q) {
	int h, h_, i, l, lg;

	if (q == 0) {
		ii[0][kk[0]++] = 0;
		return;
	}
	if (q <= 2) {
		int x;

		trace(q - 1);
		ii_.resize(2), ii_[0] = 0, ii_[1] = q, x = use_machine(ii_);
		ii[x][kk[x]++] = q;
		return;
	}
	lg = lg_[q];
	trace(q - (1 << lg));
	l = dp[q - (1 << lg)];
	for (h = 0; h < 1 << lg; h++) {
		h_ = hhh[lg][h];
		ss[h_] = sum(mask[lg][h_], l, l + nn[lg], l + nn[lg] + h);
	}
	if (lg > 0)
		ss[(1 << lg) - 1] += ss[(1 << lg) - 2];
	solve(ss, lg, l);
}

int count_mushrooms(int n) {
	int q_, q, l, r, ans;

	init();
	for (q_ = 0; q_ <= Q; q_++) {
		int n_ = dp[q_];

		for (q = 0; q < Q - q_; q++)
			n_ += (dp[q_] + q + 1) / 2;
		if (n_ >= n)
			break;
	}
	trace(q_);
	ans = n;
	for (l = dp[q_]; l < n; l = r)
		ans -= sum(NULL, l, r = min(l + max(kk[0], kk[1]), n), -1);
	ans -= kk[1];
	return ans;
}

Compilation message (stderr)

mushrooms.cpp: In function 'void init()':
mushrooms.cpp:59:58: warning: suggest parentheses around '+' inside '<<' [-Wparentheses]
   59 |   n = nn[lg], n_ = nn[lg + 1], q = 1 << lg, q_ = 1 << lg + 1;
      |                                                       ~~~^~~
mushrooms.cpp: In function 'void solve(int*, int, int)':
mushrooms.cpp:134:30: warning: suggest parentheses around '-' inside '<<' [-Wparentheses]
  134 |  for (i = 0; i + 1 < 1 << lg - 1; i++) {
      |                           ~~~^~~
mushrooms.cpp:140:54: warning: suggest parentheses around '-' inside '<<' [-Wparentheses]
  140 |   tt[i] = (ss[i0] + ss[i1] - t - c) / 2, tt[(1 << lg - 1) + i] = (ss[i0] - ss[i1] + t + c) / 2;
      |                                                   ~~~^~~
mushrooms.cpp:143:14: warning: suggest parentheses around '-' inside '<<' [-Wparentheses]
  143 |  tt[(1 << lg - 1) - 1] = a, tt[(1 << lg) - 1] = c;
      |           ~~~^~~
mushrooms.cpp:145:44: warning: suggest parentheses around '-' inside '<<' [-Wparentheses]
  145 |  solve(ss, lg - 1, l), solve(ss + (1 << lg - 1), lg - 1, l + nn[lg - 1]);
      |                                         ~~~^~~
mushrooms.cpp: In function 'void trace(int)':
mushrooms.cpp:149:13: warning: unused variable 'i' [-Wunused-variable]
  149 |  int h, h_, i, l, lg;
      |             ^
mushrooms.cpp: In function 'void init()':
mushrooms.cpp:53:11: warning: iteration 7 invokes undefined behavior [-Waggressive-loop-optimizations]
   53 |   cnt[lg] = (int *) malloc(q * sizeof *cnt[lg]);
      |   ~~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
mushrooms.cpp:48:18: note: within this loop
   48 |  for (lg = 0; lg <= LG; lg++) {
      |               ~~~^~~~~
mushrooms.cpp:86:13: warning: iteration 6 invokes undefined behavior [-Waggressive-loop-optimizations]
   86 |   cnt[lg + 1][q_ - 1] += cnt[lg + 1][q_ - 2];
      |   ~~~~~~~~~~^
mushrooms.cpp:58:18: note: within this loop
   58 |  for (lg = 0; lg < LG; lg++) {
      |               ~~~^~~~
mushrooms.cpp:90:24: warning: iteration 6 invokes undefined behavior [-Waggressive-loop-optimizations]
   90 |   for (i = 0; i < nn[lg]; i++)
      |                   ~~~~~^
mushrooms.cpp:88:18: note: within this loop
   88 |  for (lg = 1; lg <= LG; lg++) {
      |               ~~~^~~~~
# | Verdict | Execution time | Memory | Grader output
Fetching results...