Submission #749185

# Submission time Handle Problem Language Result Execution time Memory
749185 2023-05-27T13:22:00 Z zengminghao Alice, Bob, and Circuit (APIO23_abc) C++17
100 / 100
1278 ms 421652 KB
#include "abc.h"

// you may find the definitions useful
// Gate opcodes. Each opcode is the gate's 4-entry truth table:
// f(op, x0, x1) is bit (x0 + 2 * x1) of op. For example OP_AND = 8 = 0b1000
// has only bit 3 (x0 = 1, x1 = 1) set. circuit() depends on this encoding
// when it constant-folds gates.
const int OP_ZERO    = 0;  // f(OP_ZERO,    x0, x1) = 0
const int OP_NOR     = 1;  // f(OP_NOR,     x0, x1) = !(x0 || x1)
const int OP_GREATER = 2;  // f(OP_GREATER, x0, x1) = (x0 > x1)
const int OP_NOT_X1  = 3;  // f(OP_NOT_X1,  x0, x1) = !x1
const int OP_LESS    = 4;  // f(OP_LESS,    x0, x1) = (x0 < x1)
const int OP_NOT_X0  = 5;  // f(OP_NOT_X0,  x0, x1) = !x0
const int OP_XOR     = 6;  // f(OP_XOR,     x0, x1) = (x0 ^ x1)
const int OP_NAND    = 7;  // f(OP_NAND,    x0, x1) = !(x0 && x1)
const int OP_AND     = 8;  // f(OP_AND,     x0, x1) = (x0 && x1)
const int OP_EQUAL   = 9;  // f(OP_EQUAL,   x0, x1) = (x0 == x1)
const int OP_X0      = 10; // f(OP_X0,      x0, x1) = x0
const int OP_GEQ     = 11; // f(OP_GEQ,     x0, x1) = (x0 >= x1)
const int OP_X1      = 12; // f(OP_X1,      x0, x1) = x1
const int OP_LEQ     = 13; // f(OP_LEQ,     x0, x1) = (x0 <= x1)
const int OP_OR      = 14; // f(OP_OR,      x0, x1) = (x0 || x1)
const int OP_ONE     = 15; // f(OP_ONE,     x0, x1) = 1

#include <bits/stdc++.h>
using namespace std;
typedef vector<int> vi;

// Encodes a lowercase name (1 to 4 letters, 'a' = 0) as a unique integer.
//
// Names are grouped by length:
//   1-letter names -> [0, 26)
//   2-letter names -> [26, 702)
//   3-letter names -> [702, 18278)
//   4-letter names -> [18278, 475254)
// Within a group the value is the base-26 reading of the letters.
// The largest code ("zzzz" = 475253) is below 2^19, so every valid name fits
// in the 19 bits emitted by alice()/bob().
// The mapping is injective but does NOT follow lexicographic order (e.g.
// "b" = 1 < "aa" = 26); it only needs to be consistent between callers.
int encode(const string& s) {
	const int len = (int)s.size();
	int res = 0;
	for (char ch : s)
		res = res * 26 + (ch - 'a');
	// Skip past the codes of all shorter names: 26 + 26^2 + ... + 26^(len-1).
	int base = 0, pw = 1;
	for (int k = 1; k < len; k++) {
		pw *= 26;
		base += pw;
	}
	return res + base;
}

// Returns the comparators of a sorting network on n wires. Each comparator is
// a pair (x, y) of 0-based wire indices with x < y. The caller is expected to
// leave the smaller value on wire x.
//
// Phase 1 sorts every aligned block of 16 wires with a fixed 16-input network:
// 60 comparators in 10 rows of s16. Each row is a flat list of pairs given as
// 0-based offsets within the block, terminated by the first pair whose second
// entry is 0. Comparators that would reach a wire >= n (in a short last block)
// are dropped. A missing wire acts like +infinity, which such a comparator
// would leave in place anyway.
//
// Phase 2 merges the sorted runs pairwise for run lengths p = 16, 32, ...,
// using the loop nest of Batcher's odd-even merge sort. It starts at p = 16
// because phase 1 already sorted runs of 16.
vector<pair<int, int>> sorting_network(int n) {
	vector<pair<int, int>> res;
	// Takes 1-based wire numbers and stores them 0-based.
	auto add = [&](int x, int y) {
		res.push_back(make_pair(x - 1, y - 1));
	};
	// Only the first 10 rows are used; unused entries are zero-initialized and
	// act as the row terminators.
	const int s16[20][20] = {
	{0,13,1,12,2,15,3,14,4,8,5,6,7,11,9,10},
	{0,5,1,7,2,9,3,4,6,13,8,14,10,15,11,12},
	{0,1,2,3,4,5,6,8,7,9,10,11,12,13,14,15},
	{0,2,1,3,4,10,5,11,6,7,8,9,12,14,13,15},
	{1,2,3,12,4,6,5,7,8,10,9,11,13,14},
	{1,4,2,6,5,8,7,10,9,13,11,14},
	{2,4,3,6,9,12,11,13},
	{3,5,6,8,7,9,10,12},
	{3,4,5,6,7,8,9,10,11,12},
	{6,7,8,9},
	};
	// Phase 1: i is the 1-based first wire of each 16-wire block.
	for (int t = 0; t < 10; t++)
		for (int i = 1; i <= n; i += 16)
			for (int j = 0; s16[t][j + 1]; j += 2)
				if (i + s16[t][j + 1] <= n)
					add(i + s16[t][j], i + s16[t][j + 1]);
	// Phase 2: merge sorted runs of length p into runs of length 2p. The
	// final check keeps each comparator inside one 2p-aligned group.
	for (int p = 16; p < n; p *= 2)
		for (int k = p; k >= 1; k /= 2)
			for (int j = k % p; j <= n - k - 1; j += 2 * k)
				for (int i = 0; i <= min(k - 1, n - j - k - 1); i++)
					if ((i + j) / (p * 2) == (i + j + k) / (p * 2))
						add(i + j + 1, i + j + k + 1);
	return res;
}

// Concatenates two int vectors: returns the elements of a followed by those
// of b.
// Operands are taken by const reference, and the result is sized once up
// front. This helper runs inside every compare-exchange of the circuit
// builder, so avoiding the two argument copies matters.
vector<int> operator + (const vector<int>& a, const vector<int>& b) {
	vector<int> c;
	c.reserve(a.size() + b.size());
	c.insert(c.end(), a.begin(), a.end());
	c.insert(c.end(), b.begin(), b.end());
	return c;
}

// Computes the switch bits of a recursive Beneš-style permutation network.
// The network routes each value of `a` (indexed by input position) to the
// position where that value appears in `b`.
//
// Assumptions (not checked): a.size() == b.size() is a power of two (alice()
// passes 1024), and a and b contain the same values, each exactly once.
//
// Returned bit order (circuit()'s unpermute() reads bits in this same order):
//   1. n/2 input-switch bits. mj[i] says which element of (a[2i], a[2i+1])
//      goes to the upper subnetwork.
//   2. The bits of the upper subnetwork.
//   3. The bits of the lower subnetwork.
//   4. n/2 output-switch bits, in output-switch order. Each says which slot
//      of (b[2v], b[2v+1]) receives the upper subnetwork's value.
// A size-n call therefore yields n * log2(n) bits; n == 1 yields none.
vector<int> permute(vector<int> a, vector<int> b) {
	int n = a.size();
	if (n == 1) return {};
	
	// revb[value] = position of value in b.
	map<int, int> revb;
	for (int i = 0; i < n; i++) revb[b[i]] = i;
	
	// Bipartite multigraph over the switches:
	//   node u < n/2 is input switch u (a[2u], a[2u+1]);
	//   node n/2 + v is output switch v (b[2v], b[2v+1]).
	// Each value adds one edge from its input switch to its output switch.
	// Every node therefore has degree 2, and the graph splits into cycles.
	vector<vector<int>> G(n);
	vector<int> vis(n), match(n / 2);
	for (int i = 0; i < n; i++) {
		int u = i / 2;
		int v = n / 2 + revb[a[i]] / 2;
		G[u].push_back(v);
		G[v].push_back(u);
	}
	
	// Walk each cycle, starting from its first input switch, and alternate
	// sides. An edge walked input -> output is routed through the upper
	// subnetwork and recorded as match[u] = output switch index. All other
	// edges use the lower subnetwork. Each switch thus gets exactly one upper
	// edge.
	for (int i = 0; i < n; i++) {
		if (vis[i]) continue ;
		int u = i;
		do {
			vis[u] = 1;
			int nxt = -1;
			for (auto v : G[u])
				if (!vis[v]) nxt = v;
			if (u < n / 2 and nxt >= n / 2) {
				match[u] = nxt - n / 2;
			}
			u = nxt;
		} while (u != -1);
	}
	
	vi uppera(n / 2), upperb(n / 2), lowera(n / 2), lowerb(n / 2);
	vi mj(n / 2, -1), mk(n / 2, -1);
	
	// For input switch i, find the element (offset mj) that lands in slot mk
	// of output switch match[i]. That element goes upper; its partner goes
	// lower. This also builds the sub-permutations for both subnetworks.
	for (int i = 0; i < n / 2; i++) {
		for (int j = 0; j < 2; j++)
			for (int k = 0; k < 2; k++)
				if (a[2 * i + j] == b[2 * match[i] + k])
					mj[i] = j, mk[i] = k;
		if (mj[i] == -1 or mk[i] == -1) {
			// Defensive fallback; should not trigger when a and b hold the
			// same values.
			mj[i] = mk[i] = 0;
		}
		uppera[i] = a[2 * i + mj[i]];
		lowera[i] = a[2 * i + 1 - mj[i]];
		upperb[match[i]] = b[2 * match[i] + mk[i]];
		lowerb[match[i]] = b[2 * match[i] + 1 - mk[i]];
	}
	
	vi res;
	for (int i = 0; i < n / 2; i++)
		res.push_back(mj[i]);
	res = res + permute(uppera, upperb);
	res = res + permute(lowera, lowerb);
	
	// Output-switch bits are emitted in output-switch order, so reindex mk
	// (stored per input switch) through match.
	vector<int> proc(n / 2);
	for (int i = 0; i < n / 2; i++)
		proc[match[i]] = mk[i];
	for (int i = 0; i < n / 2; i++)
		res.push_back(proc[i]);
	return res;
}

// Alice
int // returns la
alice(
    /*  in */ const int n,
    /*  in */ const char names[][5],
    /*  in */ const unsigned short numbers[],
    /* out */ bool outputs_alice[]
) {
	vector<int> a(n);
	for (int i = 0; i < n; i++) a[i] = i;
	sort(a.begin(), a.end(), [&](int x, int y) {
		return encode(names[x]) < encode(names[y]);
	});
	int offset = 0;
	for (int i : a) {
		int x = encode(names[i]), y = numbers[i];
		for (int j = 0; j < 19; j++)
			outputs_alice[offset + j] = x >> j & 1;
		for (int j = 0; j < 16; j++)
			outputs_alice[offset + 19 + j] = y >> j & 1;
		offset += 35;
	}
	
	for (int i = n; i < 1024; i++) a.push_back(i);
	vector<int> b;
	for (int i = 0; i < 1024; i++) b.push_back(i);
	auto switches = permute(a, b);
	for (auto i : switches) {
		outputs_alice[offset] = i;
		offset++;
	}
	
    return 35 * n + 10240;
}


// Bob
// Writes Bob's letters sorted by encoded sender, 38 bits each:
// 19 bits of the encoded sender, then 19 bits of the encoded recipient
// (LSB first). Returns the number of bits written (38 * m).
int // returns lb
bob(
    /*  in */ const int m,
    /*  in */ const char senders[][5],
    /*  in */ const char recipients[][5],
    /* out */ bool outputs_bob[]
) {
	// order[k] = index of the letter with the k-th smallest sender code.
	vector<int> order(m);
	iota(order.begin(), order.end(), 0);
	sort(order.begin(), order.end(), [&](int lhs, int rhs) {
		return encode(senders[lhs]) < encode(senders[rhs]);
	});

	int written = 0;
	auto emit19 = [&](int code) {
		for (int bit = 0; bit < 19; bit++)
			outputs_bob[written++] = (code >> bit) & 1;
	};
	for (const int idx : order) {
		emit19(encode(senders[idx]));
		emit19(encode(recipients[idx]));
	}
	return written;
}

// Returns the inclusive index range between l and r of a.
//   l <= r : a[l], a[l+1], ..., a[r]   (ascending indices)
//   l >  r : a[l], a[l-1], ..., a[r]   (descending indices)
// Indices must be in range for a; they are not checked.
vector<int> subsegment(const vector<int>& a, int l, int r) {
	vector<int> b;
	if (l <= r) {
		b.assign(a.begin() + l, a.begin() + r + 1);
	} else {
		// Walk from l down to r.
		for (int i = l; i >= r; i--) b.push_back(a[i]);
	}
	return b;
}

// Circuit
// Builds the gate list that, for every Alice entry (in Alice's original input
// order), outputs the 16-bit sum (mod 2^16) of the sender numbers attached to
// the letters addressed to that entry's name.
//
// Wires [0, la) are Alice's bits and [la, la + lb) are Bob's bits. Every new
// gate is appended after them. Returns the total number of wires.
//
// Pipeline:
//   1. Bitonic-merge Alice's entries with Bob's letters by (name, type), so
//      each Alice entry precedes the letters it sent.
//   2. Sweep, carrying the latest Alice number onto each following letter;
//      re-key every letter by its recipient.
//   3. Sort by (recipient, type) so letters precede the Alice entry they are
//      addressed to. Sweep with a running sum that each Alice entry captures
//      and then resets.
//   4. Sort by (type, position) to bring Alice's results to the front in her
//      name-sorted order. Then undo her sort with the switch bits she sent.
int // returns l
circuit(
    /*  in */ const int la,
    /*  in */ const int lb,
    /* out */ int operations[],
    /* out */ int operands[][2],
    /* out */ int outputs_circuit[][16]
) {
	// Index of the most recently created wire.
	int tot = la + lb - 1;
	
	// Appends a gate without any simplification and returns its wire index.
	auto getraw = [&](int op, int x, int y) {
		tot++; operations[tot] = op;
		operands[tot][0] = x;
		operands[tot][1] = y;
		return tot;
	};
	
	// Shared constant wires, used for constant folding below.
	int zero = getraw(OP_ZERO, 0, 0);
	int one = getraw(OP_ONE, 0, 0);
	
	// Appends a gate, constant-folding it in three cases:
	//   - an operand is the `zero` or `one` wire;
	//   - both operands are the same wire.
	// f(op, x0, x1) is bit (x0 + 2 * x1) of op. A gate that reduces to
	// "not w" is emitted as NAND(w, w).
	auto get = [&](int op, int x, int y) {
		int vx = -1; if (x == zero) vx = 0; if (x == one) vx = 1;
		int vy = -1; if (y == zero) vy = 0; if (y == one) vy = 1;
		if (vx != -1 and vy != -1) {
			return (op >> (vx + 2 * vy) & 1) ? one : zero;
		}
		if (vx != -1) {
			int v0 = op >> (vx + 2 * 0) & 1;
			int v1 = op >> (vx + 2 * 1) & 1;
			if (v0 == 0 and v1 == 0) return zero;
			if (v0 == 1 and v1 == 1) return one;
			if (v0 == 0 and v1 == 1) return y;
			if (v0 == 1 and v1 == 0) return getraw(OP_NAND, y, y);
		}
		if (vy != -1) {
			int v0 = op >> (0 + 2 * vy) & 1;
			int v1 = op >> (1 + 2 * vy) & 1;
			if (v0 == 0 and v1 == 0) return zero;
			if (v0 == 1 and v1 == 1) return one;
			if (v0 == 0 and v1 == 1) return x;
			if (v0 == 1 and v1 == 0) return getraw(OP_NAND, x, x);
		}
		if (x == y) {
			int v0 = op >> (0 + 2 * 0) & 1;
			int v1 = op >> (1 + 2 * 1) & 1;
			if (v0 == 0 and v1 == 0) return zero;
			if (v0 == 1 and v1 == 1) return one;
			if (v0 == 0 and v1 == 1) return x;
			if (v0 == 1 and v1 == 0) return getraw(OP_NAND, x, x);
		}
		return getraw(op, x, y);
	};
	
	auto getNOT = [&](int x) {return get(OP_NAND, x, x);};
	auto getAND = [&](int x, int y) {return get(OP_AND, x, y);};
	auto getOR  = [&](int x, int y) {return get(OP_OR,  x, y);};
	auto getXOR = [&](int x, int y) {return get(OP_XOR, x, y);};
	auto getIF  = [&](int x, int y, int z) { // x ? y : z
		if (x == one) return y;
		if (x == zero) return z;
		if (y == z) return y; // both branches are the same wire
		return getOR(getAND(x, y), get(OP_LESS, x, z)); // (x & y) | (!x & z)
	};
	
	// n-bit adder modulo 2^n, index 0 least significant (5n gates)
	auto getADD = [&](vector<int> a, vector<int> b) {
		vector<int> c((int)a.size(), zero);
		int carry = zero;
		for (int i = 0; i < (int)a.size(); i++) {
			int x = getXOR(a[i], b[i]);
			c[i] = getXOR(x, carry);
			carry = getOR(getAND(a[i], b[i]), getAND(x, carry));
		}
		return c;
	};
	
	// n-bit compare, index 0 least significant: true if a < b, false
	// otherwise (4n gates)
	auto getCMP = [&](vector<int> a, vector<int> b) {
		int res = get(OP_LESS, a[0], b[0]);
		for (int i = 1; i < (int)a.size(); i++)
			res = getOR(get(OP_LESS, a[i], b[i]),
				 getAND(get(OP_EQUAL, a[i], b[i]), res));
		return res;
	};
	
	// n-bit selector: a if op is true, b otherwise (3n gates)
	auto getSELECT = [&](int op, vector<int> a, vector<int> b) {
		if (op == one) return a;
		if (op == zero) return b;
		vector<int> c((int)a.size(), zero);
		for (int i = 0; i < (int)a.size(); i++)
			c[i] = getIF(op, a[i], b[i]);
		return c;
	};
	
	// n-bit conditional swap: (a, b) if op is true, (b, a) otherwise
	// (4n gates)
	auto getSELECT2 = [&](int op, vector<int> a, vector<int> b) {
		if (op == one) return make_pair(a, b);
		if (op == zero) return make_pair(b, a);
		vector<int> A((int)a.size(), zero);
		vector<int> B((int)a.size(), zero);
		for (int i = 0; i < (int)a.size(); i++) {
			if (a[i] == b[i]) {A[i] = a[i], B[i] = b[i]; continue ;}
			int ne = getXOR(a[i], b[i]);
			A[i] = getXOR(a[i], getAND(op, ne));
			B[i] = getXOR(ne, A[i]);
		}
		return make_pair(B, A);
	};
	
	// Alice sends 35 bits per entry plus 10240 switch bits; Bob sends 38 bits
	// per letter.
	int n = (la - 10240) / 35;
	int m = lb / 38;
	
	// Initial 39-wire records:
	//   Alice: name[0..18], type = 0 at [19], number[20..35], zero padding.
	//   Bob:   sender[0..18], type = 1 at [19], recipient[20..38].
	vector<vector<int>> info;
	for (int i = 0; i < n; i++) {
		int offset = 35 * i;
		vector<int> a;
		for (int j = 0; j < 19; j++) a.push_back(offset + j);
		a.push_back(zero);
		for (int j = 0; j < 16; j++) a.push_back(offset + 19 + j);
		for (int j = 0; j < 3; j++) a.push_back(zero);
		info.push_back(a);
	}
	// Alice's entries arrive in ascending order and Bob's letters also arrive
	// ascending. Reversing Alice's block makes the whole list
	// descending-then-ascending, which is what MERGE expects.
	reverse(info.begin(), info.end());
	for (int i = 0; i < m; i++) {
		int offset = la + 38 * i;
		vector<int> b;
		for (int j = 0; j < 19; j++) b.push_back(offset + j);
		b.push_back(one);
		for (int j = 0; j < 19; j++) b.push_back(offset + 19 + j);
		info.push_back(b);
	}
	
	// Sort key used by CMP: wires [cl, cr] form the least significant part
	// and wires [cl2, cr2] the most significant part.
	int cl = 19, cr = 19, cl2 = 0, cr2 = 18;
	
	// Compare-exchange: afterwards info[x] holds the record with the smaller
	// key (records with equal keys are swapped).
	auto CMP = [&](int x, int y) {
		vector<int> sa = subsegment(info[x], cl, cr) + subsegment(info[x], cl2, cr2);
		vector<int> sb = subsegment(info[y], cl, cr) + subsegment(info[y], cl2, cr2);
		int op = getCMP(sa, sb);
		auto nxt = getSELECT2(op, info[x], info[y]);
		info[x] = nxt.first;
		info[y] = nxt.second;
	};
	
	// Bitonic merge of info[lo, lo + len). Compare each element with the one
	// `half` positions later, where half is the largest power of two below
	// len, then recurse on both parts.
	function<void(int, int)> MERGE = [&](int lo, int len) {
		if (len <= 1) return ;
		int half = 1;
		while (half < len) half <<= 1;
		half >>= 1;
		for (int i = lo; i + half < lo + len; i++) CMP(i, i + half);
		MERGE(lo, half);
		MERGE(lo + half, len - half);
	};
	
	// Full sort of info with a comparator network.
	auto SORT = [&]() {
		for (auto wires : sorting_network((int)info.size()))
			CMP(wires.first, wires.second);
	};
	
	// Step 1: sort by sender name (most significant), then type (least
	// significant), so each Alice entry precedes the letters it sent.
	cl = 19, cr = 19, cl2 = 0, cr2 = 18;
	MERGE(0, n + m);
	
	// Step 2: carry the latest Alice number (val) onto the following letters.
	// Each record becomes 36 wires:
	//   [0..18]  recipient for letters, own name for Alice entries;
	//   [19]     1 for Alice, 0 for letters;
	//   [20..35] carried number for letters, zero for Alice entries.
	vector<int> val(16, zero);
	for (int i = 0; i < n + m; i++) {
		int type = info[i][19]; // 0: alice, 1: bob
		vector<int> name = getSELECT(type, subsegment(info[i], 20, 38), subsegment(info[i], 0, 18));
		val = getSELECT(type, val, subsegment(info[i], 20, 35));
		
		info[i] = vector<int>(36, zero);
		for (int j = 0; j < 19; j++) info[i][j] = name[j];
		info[i][19] = getNOT(type);
		for (int j = 0; j < 16; j++) info[i][20 + j] = getAND(type, val[j]);
	}
	
	// Step 3: sort by recipient, then type (letters before the Alice entry
	// with that name). Then run a sum that each Alice entry captures and
	// resets.
	cl = 19, cr = 19, cl2 = 0, cr2 = 18;
	SORT();
	
	vector<int> sum(16, zero);
	for (int i = 0; i < n + m; i++) {
		sum = getADD(sum, subsegment(info[i], 20, 35));
		
		int type = info[i][19]; // 0: bob, 1: alice
		// New 28-wire record:
		//   [0..10]  position i as constant wires (n + m must fit in 11 bits);
		//   [11]     0 for Alice, 1 for letters;
		//   [12..27] running sum.
		vector<int> a(28);
		for (int j = 0; j < 11; j++)
			a[j] = (i >> j & 1 ? one : zero);
		a[11] = getNOT(type);
		for (int j = 0; j < 16; j++)
			a[12 + j] = sum[j];
		info[i] = a;
		
		// Reset the running sum after an Alice entry.
		for (int j = 0; j < 16; j++)
			sum[j] = getAND(sum[j], a[11]);
	}
	
	// Step 4: sort by type (most significant), then position, so Alice's n
	// entries come first, in the name-sorted order she used.
	cl = 0, cr = 10, cl2 = 11, cr2 = 11;
	SORT();
	
	// Alice's switch bits start right after her 35 * n data bits.
	int ptr = 35 * n;
	auto read = [&]() {
		int ans = ptr;
		ptr++;
		return ans;
	};
	
	// Keep only Alice's sums and pad to the 1024 slots of her network.
	vector<vector<int>> ord;
	for (int i = 0; i < n; i++) {
		vector<int> a;
		for (int j = 0; j < 16; j++)
			a.push_back(info[i][12 + j]);
		ord.push_back(a);
	}
	for (int i = n; i < 1024; i++) {
		ord.push_back(vector<int>(16, zero));
	}
	info = ord;
	
	// Mirror of permute(): consumes switch bits in the same order permute()
	// produced them. Input switches first, then the upper and lower halves
	// recursively, then output switches.
	function<vector<vi>(vector<vi>)> unpermute = [&](vector<vector<int>> a) {
		int len = a.size();
		if (len == 1) return a;
		
		vector<int> mj;
		for (int i = 0; i < len / 2; i++)
			mj.push_back(read());
		
		vector<vector<int>> l(len / 2), r(len / 2);
		for (int i = 0; i < len / 2; i++) {
			auto obj = getSELECT2(mj[i], a[2 * i + 1], a[2 * i]);
			l[i] = obj.first;
			r[i] = obj.second;
		}
		
		l = unpermute(l);
		r = unpermute(r);
		
		vector<int> mk;
		for (int i = 0; i < len / 2; i++)
			mk.push_back(read());
		
		vector<vector<int>> res(len);
		for (int i = 0; i < len / 2; i++) {
			auto obj = getSELECT2(mk[i], r[i], l[i]);
			res[2 * i] = obj.first;
			res[2 * i + 1] = obj.second;
		}
		
		return res;
	};
	
	// Restore Alice's original input order.
	info = unpermute(info);
	
	for (int i = 0; i < n; i++)
		for (int j = 0; j < 16; j++)
			outputs_circuit[i][j] = info[i][j];
	
	return tot + 1;
}
# Verdict Execution time Memory Grader output
1 Correct 209 ms 5456 KB Correct!
# Verdict Execution time Memory Grader output
1 Correct 209 ms 5456 KB Correct!
2 Correct 223 ms 12468 KB Correct!
# Verdict Execution time Memory Grader output
1 Correct 209 ms 5456 KB Correct!
2 Correct 223 ms 12468 KB Correct!
3 Correct 726 ms 211984 KB Correct!
4 Correct 728 ms 212692 KB Correct!
# Verdict Execution time Memory Grader output
1 Correct 247 ms 15620 KB Correct!
2 Correct 490 ms 116056 KB Correct!
3 Correct 558 ms 148740 KB Correct!
# Verdict Execution time Memory Grader output
1 Correct 247 ms 15620 KB Correct!
2 Correct 490 ms 116056 KB Correct!
3 Correct 558 ms 148740 KB Correct!
4 Correct 510 ms 116432 KB Correct!
5 Correct 564 ms 148684 KB Correct!
# Verdict Execution time Memory Grader output
1 Correct 247 ms 15620 KB Correct!
2 Correct 490 ms 116056 KB Correct!
3 Correct 558 ms 148740 KB Correct!
4 Correct 510 ms 116432 KB Correct!
5 Correct 564 ms 148684 KB Correct!
6 Correct 437 ms 89980 KB Correct!
7 Correct 711 ms 197456 KB Correct!
# Verdict Execution time Memory Grader output
1 Correct 1237 ms 414308 KB Correct!
2 Correct 1229 ms 415012 KB Correct!
# Verdict Execution time Memory Grader output
1 Correct 1237 ms 414308 KB Correct!
2 Correct 1229 ms 415012 KB Correct!
3 Correct 1198 ms 386584 KB Correct!
4 Correct 1278 ms 414660 KB Correct!
# Verdict Execution time Memory Grader output
1 Correct 209 ms 5456 KB Correct!
2 Correct 223 ms 12468 KB Correct!
3 Correct 726 ms 211984 KB Correct!
4 Correct 728 ms 212692 KB Correct!
5 Correct 247 ms 15620 KB Correct!
6 Correct 490 ms 116056 KB Correct!
7 Correct 558 ms 148740 KB Correct!
8 Correct 510 ms 116432 KB Correct!
9 Correct 564 ms 148684 KB Correct!
10 Correct 437 ms 89980 KB Correct!
11 Correct 711 ms 197456 KB Correct!
12 Correct 1237 ms 414308 KB Correct!
13 Correct 1229 ms 415012 KB Correct!
14 Correct 1198 ms 386584 KB Correct!
15 Correct 1278 ms 414660 KB Correct!
16 Correct 1272 ms 421012 KB Correct!
17 Correct 1271 ms 421652 KB Correct!