Submission #829215

# Submission time Handle Problem Language Result Execution time Memory
829215 2023-08-18T06:40:51 Z ngrace Bit Shift Registers (IOI21_registers) C++17
100 / 100
1 ms 468 KB
#include "registers.h"

//comparison (max case n=100, k=10):
//s=0: basic_min - 2972, parallel_min - 238, kmin - 94
//n=2, k=2: parallel_min - 22, kmin - 20
//s=1, n=10: basic_sort - 1424
//s=1: parallel_sort - 3200, odd_even_merge_sort - 1666 (unoptimised, straightforward implementation), odd_even_merge_sort - 1163 (optimised)
//s=1: double_parallel_sort - 1066, optimised_double_parallel_sort - 982

//overall: s=0 is 94 instructions, s=1 is 982 instructions

#include <bits/stdc++.h>
using namespace std;
// Short alias for std::vector used throughout this file. An alias template is used
// instead of `#define v vector`, which would rewrite every later identifier named `v`.
template <class T> using v = std::vector<T>;

// Register width in bits; every vector passed to append_store in this file has this length.
const int b=2000;
// Number of input values and bits per value; assigned in construct_instructions().
int n, k;

void push_down(int reg, int num_down, int current_done = 1){
	// Smears every 1 bit of register `reg` toward bit 0, so that it covers num_down
	// consecutive positions ending at (and including) itself.
	// Example: 001000101 becomes 001110111 when num_down is 3.
	// `current_done` is how many positions are already covered on entry. Each round
	// (two instructions) doubles the covered span, capped at num_down.
	// Register 99 is used as scratch.
	for(int covered = current_done; covered < num_down; ){
		const int shift = std::min(covered, num_down - covered);
		append_right(99, reg, shift);
		append_or(reg, 99, reg);
		covered += shift;
	}
}

void push_up(int reg, int num_up, int current_done = 1){
	// Mirror image of push_down: smears every 1 bit of register `reg` toward the most
	// significant end, so that it covers num_up consecutive positions starting at
	// (and including) itself.
	// `current_done` (default 1, same signature shape as push_down) is the span already
	// covered on entry. Each round costs two instructions and doubles the span, capped
	// at num_up.
	// Register 99 is used as scratch.
	while(current_done<num_up){
		int mov = std::min(2*current_done, num_up) - current_done;
		append_left(99, reg, mov);
		append_or(reg, 99, reg);
		current_done+=mov;
	}
}

void basic_swap(int a, int b, int mask, int one){
	// Compare-exchange of blocks a and b (a < b) of register 0. If the k-bit number in
	// block a is greater than the one in block b, the two are swapped, so block a ends
	// up with the smaller value.
	// `mask` must hold k low ones and `one` the value 1.
	//
	// Cost: 14 instructions plus push_down(k) and push_down(2k), which cost two
	// instructions per doubling round each. When a == 0, block a is already aligned
	// and two instructions are saved.
	// Scratch registers: 1-7, plus 99 through push_down.
	const bool aligned = (a == 0);
	const int numA = aligned ? 0 : 1;  // register whose low k bits hold block a

	// Bring both numbers down to the low k bits.
	if(!aligned) append_right(numA, 0, a*k);
	append_right(2, 0, b*k);

	// Register 3: bits where the two numbers differ, restricted to the low k bits.
	append_xor(3, numA, 2);
	append_and(3, 3, mask);

	// Register 4: only the most significant differing bit.
	// Smearing it down gives 0011111; adding one gives 0100000; shifting right once
	// gives 0010000. With no difference the result is (0 + 1) >> 1 = 0.
	append_move(4, 3);
	push_down(4, k);
	append_add(4, 4, one);
	append_right(4, 4, 1);

	// Register 5: nonzero iff block a has a 1 at that bit, i.e. a is the larger value.
	// Shifting up by k and smearing 2k down turns at least the whole low k bits into ones.
	append_and(5, 4, numA);
	append_left(5, 5, k);
	push_down(5, 2*k);

	// Register 6: swap key = difference & condition.
	// XOR-ing it into both blocks exchanges them, since x ^ (x ^ y) == y.
	append_and(6, 3, 5);
	const int keyA = aligned ? 6 : 7;
	if(!aligned) append_left(keyA, 6, a*k);
	append_xor(0, 0, keyA);
	append_left(7, 6, b*k);
	append_xor(0, 0, 7);
}

void basic_min(){
	v<bool> mask(b, false);
	v<bool> one(b, false);
	for(int i=0; i<k; i++) mask[i]=true;
	one[0]=true;
	append_store(20, mask);
	append_store(21, one);
	for(int i=1; i<n; i++){
		basic_swap(0, i, 20, 21);
	}
}

void basic_sort(){
	v<bool> mask(b, false);
	v<bool> one(b, false);
	for(int i=0; i<k; i++) mask[i]=true;
	one[0]=true;
	append_store(20, mask);
	append_store(21, one);
	for(int i=n-1; i>0; i--){
		for(int j=0; j<i; j++){
			basic_swap(j, j+1, 20, 21);
		}
	}
}

void parallel_swap(int dist){
	//swaps (if necessary) a set of index disjoint pairs at once.
	//setup: assume register 0-3 are:
	//0: the input
	//1: the second of each pair, in the same position (second of each pair is dist to the right of first)
	append_right(1, 0, dist*k);//assumption that want to sort indices with the index right above it
	//2: mask, k 1's in blocks that indicate that section of 0 should be sorted with same section in 1, assumption that no adjacent blocks both 1
	//3: ones, similar to 1 but only first bit in each block is 1

	append_xor(4, 0, 1);
	append_and(4, 4, 2);

	//isolate most significant bit of each block (see basic_swap()):
	/* old
	append_move(5, 4);
	push_down(5, k);
	*/
	/* new*/
	append_right(5, 4, 1);
	append_or(5, 4, 5);
	push_down(5, k, 2);
	/*^avoids useless move*/
	append_add(5, 5, 3);
	append_right(5, 5, 1);
	append_and(5, 5, 2);

	//turn isolated most significant bit into mask (1 in block if should swap, 0 if not)
	append_and(5, 5, 0);
	append_left(5, 5, k);
	push_down(5, 2*k);
	//append_and(5, 5, 2);//unnecesary as below 5 is anded with 4, which was allready anded with 2

	//get key (xor & mask), since xor is associative this xored with 0 and 1 will perform swap
	append_and(5, 5, 4);
	append_xor(0, 0, 5);
	append_left(5, 5, dist*k);
	append_xor(0, 0, 5);
}

void parallel_min(){
	int every = 2;
	while(every/2<n){
		v<bool> mask(b, false);
		v<bool> ones(b, false);
		for(int i=0; i+every/2<n; i+=every){
			ones[i*k]=true;
			for(int j=0; j<k; j++) mask[i*k+j]=true;
		}
		append_store(2, mask);
		append_store(3, ones);
		
		parallel_swap(every/2);

		every*=2;
	}
}

void parallel_sort(){
	// Odd-even transposition sort: n rounds that alternately compare pairs
	// (0,1),(2,3),... and (1,2),(3,4),...; each round is one parallel_swap at distance 1.
	// Rounds that contain no pair (e.g. the odd rounds when n == 2) are skipped: with
	// empty masks parallel_swap would leave register 0 unchanged, so they only cost
	// instructions.
	for(int round=0; round<n; round++){
		v<bool> mask(b, false);
		v<bool> ones(b, false);
		bool anyPair = false;
		for(int first = round % 2; first + 1 < n; first += 2){
			anyPair = true;
			ones[first*k]=true;
			for(int j=0; j<k; j++) mask[first*k+j]=true;
		}
		if(!anyPair) continue;

		append_store(2, mask);
		append_store(3, ones);
		parallel_swap(1);
	}
}

// Unoptimised version (1666 instructions)
void simple_odd_even_sort(){
	//https://en.wikipedia.org/wiki/Batcher_odd%E2%80%93even_mergesort
	v<bool> block(b, false);
	for(int i=n*k; i<b; i++) block[i]=true;
	append_store(1, block);
	append_or(0, 0, 1);
	int tmp=1;
	while(tmp<n) tmp<<=1;
	n=tmp;
	for(int p=1; p<n; p<<=1){
		for(int l=p; l>0; l>>=1){
			bool doEvens=false, doOdds=false;
			v<bool> even(b, false), evenones(b, false);
			v<bool> odd(b, false), oddones(b, false);
			for(int j=l%p; j<n-l; j+=2*l){
				for(int i=0; i <= min(l-1, n-j-l-1); i++){
					if ((i+j) / (p*2) == (i+j+l) / (p*2)){
						if((i+j)%2==0){
							doEvens=true;
							evenones[k*(i+j)] = true;
							for(int t=0; t<k; t++) even[k*(i+j)+t] = true;
						}
						else{
							doOdds=true;
							oddones[k*(i+j)] = true;
							for(int t=0; t<k; t++) odd[k*(i+j)+t] = true;
						}
						// v<bool> mask(b, false);
						// v<bool> one(b, false);
						// for(int i=0; i<k; i++) mask[i]=true;
						// one[0]=true;
						// append_store(20, mask);
						// append_store(21, one);
						// basic_swap(i+j, i+j+l, 20, 21);
						//cout<<i+j<<" "<<i+j+l<<endl;
					}
				}
			}
			//cout<<"end layer"<<endl;
			//continue;
			if(doEvens){
				append_store(2, even);
				append_store(3, evenones);
				parallel_swap(l);
			}
			if(doOdds){
				append_store(2, odd);
				append_store(3, oddones);
				parallel_swap(l);
			}
		}
	}
}

void optimised_odd_even_sort(){//1163 instructions
	// Batcher's odd-even merge sort (https://en.wikipedia.org/wiki/Batcher_odd%E2%80%93even_mergesort),
	// using a single parallel_swap per layer.
	//
	// parallel_swap requires that no two selected blocks are adjacent. When a layer
	// violates that, the blocks of its odd-first comparators are copied n blocks up.
	// Even-first comparators are then run in place and odd-first ones in the copy,
	// with one parallel_swap, and the odd results are copied back down.
	// Registers: 39 = ones over the first n*k bits; 40 and 41 = scratch for the copy.
	v<bool> blockMask(b, false);
	for(int i=0; i<k*n; i++) blockMask[i]=true;
	append_store(39, blockMask);
	for(int p=1; p<n; p<<=1){
		for(int l=p; l>0; l>>=1){
			// simplemask/simpleones: every comparator of the layer, in place.
			// mask/ones: even-first comparators in place, odd-first ones n blocks higher.
			// oddMask: both blocks of every odd-first comparator.
			v<bool> oddMask(b, false);
			v<bool> mask(b, false), ones(b, false);
			v<bool> simplemask(b, false), simpleones(b, false);
			for(int j=l%p; j<n-l; j+=2*l){
				for(int i=0; i <= std::min(l-1, n-j-l-1); i++){
					if ((i+j) / (p*2) == (i+j+l) / (p*2) && i+j+l<n){
						simpleones[k*(i+j)] = true;
						for(int t=0; t<k; t++) simplemask[k*(i+j)+t] = true;

						if((i+j)%2==0){
							ones[k*(i+j)] = true;
							for(int t=0; t<k; t++) mask[k*(i+j)+t] = true;
						}
						else{
							ones[k*(i+j) + k*n] = true;
							for(int t=0; t<k; t++){
								oddMask[k*(i+j) + t] = true;
								oddMask[k*(i+j+l) + t] = true;
								mask[k*(i+j)+t + k*n] = true;
							}
						}
					}
				}
			}

			// Two selected blocks next to each other would violate parallel_swap's precondition.
			bool needDouble = false;
			for(int i=0; i<n; i++){
				if(simplemask[i*k] && simplemask[i*k+k]) needDouble = true;
			}

			if(needDouble){
				append_store(40, oddMask);
				append_and(41, 0, 40);       // blocks used by odd-first comparators
				append_left(41, 41, n*k);    // ...copied n blocks up
				append_or(0, 41, 0);
				append_store(2, mask);
				append_store(3, ones);
				parallel_swap(l);
				append_right(1, 0, n*k);     // bring the processed copy back down
				append_and(1, 1, 40);
				append_not(40, 40);
				append_and(0, 0, 40);        // clear the stale odd blocks in place
				append_or(0, 0, 1);
				append_and(0, 0, 39);        // drop everything above the n*k data bits
			}
			else{
				append_store(2, simplemask);
				append_store(3, simpleones);
				parallel_swap(l);
			}
		}
	}
}

void double_parallel_swap(int dist){
	//swaps (if necessary) a set of index disjoint pairs at once.
	//setup: assume register 0-3 are:
	//0: the input
	//1: the second of each pair, in the same position (second of each pair is dist to the right of first)
	append_right(1, 0, dist*k);
	//2: mask, k 1's in blocks that indicate that section of 0 should be sorted with same section in 1, assumption that no adjacent blocks both 1
	//3: ones, similar to 1 but only first bit in each block is 1

	append_xor(4, 0, 1);
	append_and(4, 4, 2);

	//isolate most significant bit of each block (see basic_swap()):
	append_right(5, 4, 1);
	append_or(5, 4, 5);
	push_down(5, k, 2);
	append_add(5, 5, 3);
	append_right(5, 5, 1);
	append_and(5, 5, 2);

	//turn isolated most significant bit into mask (1 in block if should swap, 0 if not)
	append_and(5, 5, 0);
	append_left(5, 5, k);
	push_down(5, 2*k);
	//append_and(5, 5, 2);
	

	//get key (xor & mask), since xor is associative this xored with 0 and 1 will perform swap
	append_and(5, 5, 4);
	append_right(6, 5, n*k);
	append_left(7, 5, n*k);
	append_or(5, 5, 6);
	append_or(5, 5, 7);
	append_xor(0, 0, 5);
	append_left(5, 5, dist*k);
	append_xor(0, 0, 5);
}
void double_odd_even_sort(){//1066 instructions
	// Batcher's odd-even merge sort (https://en.wikipedia.org/wiki/Batcher_odd%E2%80%93even_mergesort)
	// on a register holding two copies of the array: blocks 0..n-1 and n..2n-1.
	// In every layer, comparators whose first index is even are selected in the lower
	// copy, and odd ones in the upper copy. No two selected blocks are then adjacent,
	// so one double_parallel_swap handles the whole layer.
	append_left(1, 0, n*k);
	append_or(0, 0, 1);   // second copy of the array in blocks n..2n-1
	for(int p=1; p<n; p<<=1){
		for(int l=p; l>0; l>>=1){
			v<bool> mask(b, false), ones(b, false);
			for(int j=l%p; j<n-l; j+=2*l){
				for(int i=0; i <= std::min(l-1, n-j-l-1); i++){
					if ((i+j) / (p*2) == (i+j+l) / (p*2) && i+j+l<n){
						if((i+j)%2==0){
							ones[k*(i+j)] = true;
							for(int t=0; t<k; t++) mask[k*(i+j)+t] = true;
						}
						else{
							ones[k*(i+j) + k*n] = true;
							for(int t=0; t<k; t++) mask[k*(i+j)+t + k*n] = true;
						}
					}
				}
			}

			append_store(2, mask);
			append_store(3, ones);
			double_parallel_swap(l);
		}
	}
}
void optimised_double_odd_even_sort(){//982 instructions
	// Batcher's odd-even merge sort (https://en.wikipedia.org/wiki/Batcher_odd%E2%80%93even_mergesort)
	// on a register holding two identical copies of the array: blocks 0..n-1 and n..2n-1.
	// For each layer:
	//  - if no two comparators start at adjacent blocks, every comparator is selected
	//    in both copies and a plain parallel_swap handles the layer, keeping both
	//    copies equal;
	//  - otherwise even-first comparators are selected in the lower copy and odd-first
	//    ones in the upper copy, and double_parallel_swap mirrors each swap into the
	//    other copy.
	// The sorted values end up in blocks 0..n-1; bits from n*k upward are left as scratch.
	append_left(1, 0, n*k);
	append_or(0, 0, 1);   // second copy of the array in blocks n..2n-1
	for(int p=1; p<n; p<<=1){
		for(int l=p; l>0; l>>=1){
			v<bool> mask(b, false), ones(b, false);
			v<bool> simplemask(b, false), simpleones(b, false);
			for(int j=l%p; j<n-l; j+=2*l){
				for(int i=0; i <= std::min(l-1, n-j-l-1); i++){
					if ((i+j) / (p*2) == (i+j+l) / (p*2) && i+j+l<n){
						simpleones[k*(i+j)] = true;
						simpleones[k*(i+j) + k*n] = true;
						for(int t=0; t<k; t++) simplemask[k*(i+j)+t] = true;
						for(int t=0; t<k; t++) simplemask[k*(i+j)+t + k*n] = true;

						if((i+j)%2==0){
							ones[k*(i+j)] = true;
							for(int t=0; t<k; t++) mask[k*(i+j)+t] = true;
						}
						else{
							ones[k*(i+j) + k*n] = true;
							for(int t=0; t<k; t++) mask[k*(i+j)+t + k*n] = true;
						}
					}
				}
			}

			// parallel_swap requires that no two selected blocks are adjacent.
			bool needDouble = false;
			for(int i=0; i<n; i++){
				if(simplemask[i*k] && simplemask[i*k+k]) needDouble = true;
			}

			if(needDouble){
				append_store(2, mask);
				append_store(3, ones);
				double_parallel_swap(l);
			}
			else{
				append_store(2, simplemask);
				append_store(3, simpleones);
				parallel_swap(l);
			}
		}
	}
}

void kmin(){
	v<bool> mask(b, false), maskall(b, true);
	for(int i=0; i<n; i++) mask[i*k+k-1] = true;
	append_store(1, mask);
	append_store(10, maskall);
	for(int i=0; i<k; i++){
		append_and(2, 0, 1);//1 in spots where greater
		append_xor(3, 1, 2);//1 in spots where lesser - unless all were 1, in which case we need to invert all
		append_add(9, 3, 10);
		append_right(9, 9, n*k);//mask all 1 if need to keep everything, 0 otherwise
		append_and(5, 1, 9);
		append_or(1, 3, 5);

		/* old approach to get mask all 1 if need to keep everything
		append_move(4, 3);
		push_down(4, n*k);//BAD
		push_up(4, n*k);//BAD
		append_not(4, 4);//mask all 1 if need to keep everything, 0 otherwise
		*/

		if(i!=k-1) append_right(1, 1, 1);
	}
	push_up(1, k);
	append_and(0, 0, 1);
	//push down by unit blocks of k
	int current_done = 1;
	while(current_done<n){
		int mov = min(2*current_done, n) - current_done;
		append_right(99, 0, mov*k);
		append_or(0, 99, 0);
		current_done+=mov;
	}
}

void construct_instructions(int S, int N, int K, int Q) {
	// Records the problem size in the globals n and k, then dispatches on S:
	// 0 builds the minimum program (kmin), 1 builds the sorting program, and any other
	// value emits nothing. Q is not used here.
	n = N;
	k = K;
	switch(S){
		case 0:
			kmin();
			break;
		case 1:
			optimised_double_odd_even_sort();
			break;
		default:
			break;
	}
}
# Verdict Execution time Memory Grader output
1 Correct 1 ms 212 KB Output is correct
# Verdict Execution time Memory Grader output
1 Correct 1 ms 212 KB Output is correct
# Verdict Execution time Memory Grader output
1 Correct 1 ms 296 KB Output is correct
2 Correct 1 ms 212 KB Output is correct
3 Correct 0 ms 212 KB Output is correct
4 Correct 1 ms 284 KB Output is correct
# Verdict Execution time Memory Grader output
1 Correct 0 ms 212 KB Output is correct
2 Correct 1 ms 212 KB Output is correct
3 Correct 0 ms 212 KB Output is correct
4 Correct 0 ms 300 KB Output is correct
5 Correct 0 ms 212 KB Output is correct
6 Correct 0 ms 296 KB Output is correct
# Verdict Execution time Memory Grader output
1 Correct 1 ms 340 KB Output is correct
2 Correct 1 ms 340 KB Output is correct
# Verdict Execution time Memory Grader output
1 Correct 1 ms 340 KB Output is correct
2 Correct 1 ms 340 KB Output is correct
3 Correct 1 ms 468 KB Output is correct
4 Correct 1 ms 468 KB Output is correct
5 Correct 1 ms 468 KB Output is correct
6 Correct 1 ms 428 KB Output is correct
7 Correct 1 ms 468 KB Output is correct