Submission #438300

| # | Submission time | Handle | Problem | Language | Result | Execution time | Memory |
|---|---|---|---|---|---|---|---|
| 438300 | 2021-06-27T20:31:57Z | CyanForces | Bit Shift Registers (IOI21_registers) | C++17 | 100 / 100 | 6 ms | 892 KB |
#include <bits/stdc++.h>
using namespace std;

// Contest shorthands.
#define rep(i, a, b) for (int i = (a); i < (b); ++i)   // half-open [a, b)
#define trav(a, x) for (auto& a : (x))                  // iterate by reference
#define all(x) begin(x), end(x)                         // full iterator range
#define sz(x) static_cast<int>((x).size())              // size as int
using ll = long long;
using pii = pair<int, int>;
using vi = vector<int>;

#include "registers.h"

namespace SOLUTION_S0 {
  int s, n, k, q;
  int ops = 0;
  vector<bool> negOne(2000);
  vector<bool> signBit(2000);
  vector<bool> oneNum(2000);


  const int negOneIndex = 99;
  const int signBitIndex = 98;
  const int oneNumIndex = 97;
  const int workRegister = 96;

  int bit_len;

  void spread() {
    //append_print(0);
    //
    vector<bool> takeFirst(2000), takeSecond(2000);
    rep(i,0,2000) {
      int temp = (i / (k)) % 2;
      if (temp == 0) takeFirst[i] = 1;
      else takeSecond[i] = 1;
    }
    append_store(50, takeFirst);
    ops++;
    append_store(51, takeSecond);


    append_and(1, 0, 50);
    append_and(0, 0, 51);

    append_left(0, 0, (n-(n%2==0))*k);
    append_xor(1,0,1);


    //for(int i = 0; i < n; i++) {
    //  append_and(workRegister, 0, oneNumIndex);
    //  ops++;
    //  if (i != n-1) {
    //    append_left(oneNumIndex, oneNumIndex, bit_len);
    //    ops++;
    //    append_left(0, 0, k);
    //    ops++;
    //  }
    //  append_or(1, 1, workRegister);
    //  ops++;
    //}

    append_print(1);
  }

  void divide() {
    int leftNumbers = (n+1) / 2;
    //append_move(2, 1);
    ops++;
    append_right(2, 1, leftNumbers * (bit_len));
    ops++;
    if(n % 2 == 1) {
      vector<bool> tempExtra(2000);
      rep(i,0,2000) {
        tempExtra[i] = (n-leftNumbers)*(bit_len) <= i && i < (n-leftNumbers+1)*(bit_len)-1;
      }
      append_store(3, tempExtra);
      ops++;
      append_or(2, 2, 3);
      ops++;
    }
  }

  void spreadSigns(int d) {
    ops++;
    append_right(workRegister, 0, d);
    ops++;
    append_or(0, 0, workRegister);
    ops++;
  }

  void combine() {
    ops++;
    append_add(0, negOneIndex, 2); // b-1
    ops++;
    append_not(0, 0); // !(b-1)
    ops++;
    append_add(0, 0, 1); // a-b
    ops++;

    append_and(0, 0, signBitIndex);
    ops++;
    append_right(0,0,k);
    //rep(_,0,k+1) spreadSigns(1);
    int curr = 1;
    while(curr*2 < k) {
      spreadSigns(curr);
      curr *= 2;
    }
    if (k != curr) spreadSigns(k - curr);

    append_not(workRegister, 0);
    ops++;
    // registry 0 contains zeros if a >= b, else 1
    // => registry 0 contains a >= b
    // => workRegister contains a < b
    append_and(0, 0, 2);
    ops++;
    append_and(workRegister, workRegister, 1);
    ops++;

    //append_xor(1, 1, 2);
    // register 0 = xor(a, b)
    ops++;
    append_xor(1, 0, workRegister);
    ops++;
    // register 2 = max(a, b)
    //append_xor(1, 1, 2);
    ops++;
  }

  void print() {
    //rep(i,0,3) append_print(i);
  }

  void findMin() {
    spread();
    //return;
    while(n > 1) {
      //print();
      divide();
      //print();
      combine();
      //print();
      n = (n+1) / 2;
    }
    //print();
    append_move(0, 1);
    ops++;
  }

  void unspread() {
    vector<bool> empty(2000);
    append_store(0, empty);
    int shiftAmount = n-1;

    for(int i = n-1; i >= 0; i--) {
      cerr << "i = " << i << endl;
      cerr << "ops = " << ops << endl;
      append_and(workRegister, 1, oneNumIndex);
      ops++;
      append_right(workRegister, workRegister, 2*i);
      ops++;
      append_right(oneNumIndex, oneNumIndex, bit_len);
      ops++;
      if (shiftAmount > 0) {
        append_right(workRegister, workRegister, shiftAmount * k);
      }
      else {
        append_left(workRegister, workRegister, -shiftAmount * k);
      }
      shiftAmount -= 2;

      append_or(0, 0, workRegister);
      ops++;
    }

    //append_store(oneNumIndex, oneNum);
    //append_store(1, empty);
    //spread();
    //print();
  }

  void stupidSort() {
    spread();

    vector<bool> tempExtra(2000);
    rep(i,0,2000) {
      tempExtra[i] = i < k;
    }
    append_store(75, tempExtra);
    ops++;

    vector<bool> takeFirst(2000), takeSecond(2000);
    rep(i,0,2000) {
      int temp = (i / (bit_len)) % 2;
      if (temp == 0) takeFirst[i] = 1;
      else takeSecond[i] = 1;
    }
    append_store(50, takeFirst);
    ops++;
    append_store(51, takeSecond);
    ops++;
    cerr << "n = " << n << endl;
    print();
    rep(i,0,n+3) {
      print();
      cerr << "i = " << i << ", ops = " << ops << endl;
      append_move(2, 1);
      ops++;
      append_left(2, 2, bit_len);
      ops++;
      append_xor(2, 2, 75);
      ops++;
      combine();
      print();

      cerr << "i = " << i << ", ops = " << ops << endl;
      if (i % 2 == 0) {
        append_and(1, 1, 50);
        ops++;
        append_and(2, 2, 50);
        ops++;
        append_right(2, 2, bit_len);
        ops++;
      }
      else {
        append_and(1, 1, 51);
        ops++;
        append_and(2, 2, 51);
        ops++;
        append_right(2, 2, bit_len);
        ops++;
      }
      print();
      append_or(1, 1, 2);
      ops++;
      append_print(52);
    }

    unspread();
  }

  void construct_instructions(int _s, int _n, int _k, int _q) {
    s = _s;
    n = _n;
    k = _k;
    q = _q;
    bit_len = 2*k;
    int a = 0;
    int b = bit_len;
    while(b < 2000) {
      for(int j = a; j < b; j++) negOne[j] = j != b-1;
      a = b;
      b += bit_len;
    }
    for(int i = 0; i < 2000; i++) {
      signBit[i] = i % (bit_len) == bit_len-1;
      oneNum[i] = i < k;
    }
    append_store(negOneIndex, negOne);
    ops++;
    append_store(signBitIndex, signBit);
    ops++;
    //append_store(oneNumIndex, oneNum);
    ops++;

    if (s == 0) {
      findMin();
    }
    else {
      stupidSort();
    }
    //assert(ops <= q);
  }
}

namespace SOLUTION_S1 {
  int s, n, k, q;
  int ops = 0;
  vector<bool> negOne(2000);
  vector<bool> signBit(2000);
  vector<bool> oneNum(2000);

  const int negOneIndex = 99;
  const int signBitIndex = 98;
  const int oneNumIndex = 97;
  const int workRegister = 96;

  void spread() {
    for(int i = 0; i < n; i++) {
      append_and(workRegister, 0, oneNumIndex);
      ops++;
      if (i != n-1) {
        append_left(oneNumIndex, oneNumIndex, k+2);
        ops++;
        append_left(0, 0, 2);
        ops++;
      }
      append_or(1, 1, workRegister);
      ops++;
    }
  }

  void divide() {
    int leftNumbers = (n+1) / 2;
    append_move(2, 1);
    ops++;
    append_right(2, 2, leftNumbers * (k+2));
    ops++;
    vector<bool> tempExtra(2000);
    rep(i,0,2000) {
      tempExtra[i] = (n-leftNumbers)*(k+2) <= i && i < (n-leftNumbers+1)*(k+2)-1;
    }
    append_store(3, tempExtra);
    ops++;
    append_or(2, 2, 3);
    ops++;
  }

  void spreadSigns(int d) {
    append_move(workRegister, 0);
    ops++;
    append_right(workRegister, workRegister, d);
    ops++;
    append_or(0, 0, workRegister);
    ops++;
  }

  void combine() {
    append_store(0, negOne);
    ops++;
    append_add(0, 0, 2); // b-1
    ops++;
    append_not(0, 0); // !(b-1)
    ops++;
    append_add(0, 0, 1); // a-b
    ops++;

    append_and(0, 0, signBitIndex);
    ops++;
    //rep(_,0,k+1) spreadSigns(1);
    int curr = 1;
    while(curr*2 < k+2) {
      spreadSigns(curr);
      curr *= 2;
    }
    if (k+2 != curr)
      spreadSigns(k+2 - curr);

    append_not(workRegister, 0);
    ops++;
    // registry 0 contains zeros if a >= b, else 1
    // => registry 0 contains a >= b
    // => workRegister contains a < b
    append_and(0, 0, 1);
    ops++;
    append_and(workRegister, workRegister, 2);
    ops++;

    append_xor(1, 1, 2);
    // register 0 = xor(a, b)
    ops++;
    append_xor(2, 0, workRegister);
    ops++;
    // register 2 = max(a, b)
    append_xor(1, 1, 2);
    ops++;
  }

  void print() {
    rep(i,0,3) append_print(i);
  }

  void findMin() {
    spread();
    while(n > 1) {
      //print();
      divide();
      //print();
      combine();
      //print();
      n = (n+1) / 2;
    }
    //print();
    append_move(0, 1);
    ops++;
  }

  void unspread() {
    vector<bool> empty(2000);
    append_store(0, empty);
    int shiftAmount = n-1;

    for(int i = n-1; i >= 0; i--) {
      cerr << "i = " << i << endl;
      cerr << "ops = " << ops << endl;
      append_and(workRegister, 1, oneNumIndex);
      ops++;
      append_right(workRegister, workRegister, 2*i);
      ops++;
      append_right(oneNumIndex, oneNumIndex, k+2);
      ops++;
      if (shiftAmount > 0) {
        append_right(workRegister, workRegister, shiftAmount * k);
      }
      else {
        append_left(workRegister, workRegister, -shiftAmount * k);
      }
      shiftAmount -= 2;

      append_or(0, 0, workRegister);
      ops++;
    }

    //append_store(oneNumIndex, oneNum);
    //append_store(1, empty);
    //spread();
    //print();
  }

  void stupidSort() {
    spread();

    vector<bool> tempExtra(2000);
    rep(i,0,2000) {
      tempExtra[i] = i < k;
    }
    append_store(75, tempExtra);
    ops++;

    vector<bool> takeFirst(2000), takeSecond(2000);
    rep(i,0,2000) {
      int temp = (i / (k+2)) % 2;
      if (temp == 0) takeFirst[i] = 1;
      else takeSecond[i] = 1;
    }
    append_store(50, takeFirst);
    ops++;
    append_store(51, takeSecond);
    ops++;
    cerr << "n = " << n << endl;
    print();
    rep(i,0,n+3) {
      print();
      cerr << "i = " << i << ", ops = " << ops << endl;
      append_move(2, 1);
      ops++;
      append_left(2, 2, k+2);
      ops++;
      append_xor(2, 2, 75);
      ops++;
      combine();
      print();

      cerr << "i = " << i << ", ops = " << ops << endl;
      if (i % 2 == 0) {
        append_and(1, 1, 50);
        ops++;
        append_and(2, 2, 50);
        ops++;
        append_right(2, 2, k+2);
        ops++;
      }
      else {
        append_and(1, 1, 51);
        ops++;
        append_and(2, 2, 51);
        ops++;
        append_right(2, 2, k+2);
        ops++;
      }
      print();
      append_or(1, 1, 2);
      ops++;
      append_print(52);
    }

    unspread();
  }

  void construct_instructions(int _s, int _n, int _k, int _q) {
    s = _s;
    n = _n;
    k = _k;
    q = _q;
    int a = 0;
    int b = k+2;
    while(b < 2000) {
      for(int j = a; j < b; j++) negOne[j] = j != b-1;
      a = b;
      b += k+2;
    }
    for(int i = 0; i < 2000; i++) {
      signBit[i] = i % (k+2) == k+1;
      oneNum[i] = i < k;
    }
    append_store(negOneIndex, negOne);
    ops++;
    append_store(signBitIndex, signBit);
    ops++;
    append_store(oneNumIndex, oneNum);
    ops++;

    if (s == 0) {
      findMin();
    }
    else {
      stupidSort();
    }
    assert(ops <= q);
  }

}

namespace SOLUTION_N2 {
  // Strategy for s = 0 with n = 2: put the two inputs into registers 1 and 2
  // and do a single slot-wise comparison, leaving min in the low k bits of
  // register 0.  Uses (k+2)-bit slots for the comparison constants.
  //
  // `ops` counts every appended instruction (append_print is never emitted),
  // and construct_instructions() asserts it stays within q.
  int s, n, k, q;
  int ops = 0;
  vector<bool> negOne(2000);
  vector<bool> signBit(2000);
  vector<bool> oneNum(2000);

  const int negOneIndex = 99;   // per complete (k+2)-slot: bits [0, k] set, bit k+1 clear
  const int signBitIndex = 98;  // per (k+2)-slot: only bit k+1 set
  const int oneNumIndex = 97;   // low k bits
  const int workRegister = 96;  // scratch

  // register 1 = first input, register 2 = second input (each in the low k
  // bits).  Register 0 is shifted right by k in the process.
  void spread() {
    append_and(1, 0, oneNumIndex);
    ops++;
    append_right(0, 0, k);
    ops++;
    append_and(2, 0, oneNumIndex);
    ops++;
  }

  // Halving step used by the general-n strategies; not called anywhere in
  // this namespace.
  void divide() {
    int leftNumbers = (n+1) / 2;
    append_right(2, 1, leftNumbers * (k+2));
    ops++;
    if (n % 2 != 0) {
      vector<bool> tempExtra(2000);
      const int lo = (n-leftNumbers) * (k+2);
      rep(i,0,2000) tempExtra[i] = lo <= i && i < lo + (k+2) - 1;
      append_store(3, tempExtra);
      ops++;
      append_or(2, 2, 3);
      ops++;
    }
  }

  // register 0 |= register 0 >> d   (workRegister is scratch)
  void spreadSigns(int d) {
    append_right(workRegister, 0, d);
    ops++;
    append_or(0, 0, workRegister);
    ops++;
  }

  // Slot-wise with a = register 1 and b = register 2:
  //   register 0 = (b & [a >= b]) ^ (a & [a < b]) = min(a, b)
  void combine() {
    // Per slot: negOne + b, NOT, + a gives 2^(k+1) + a - b, so bit k+1 is
    // set exactly when a >= b.
    append_add(0, negOneIndex, 2);
    ops++;
    append_not(0, 0);
    ops++;
    append_add(0, 0, 1);
    ops++;
    append_and(0, 0, signBitIndex);
    ops++;

    // Smear bit k+1 down over the whole (k+2)-bit slot.
    int curr = 1;
    while (curr*2 < k+2) {
      spreadSigns(curr);
      curr *= 2;
    }
    if (k+2 != curr) spreadSigns(k+2 - curr);

    append_not(workRegister, 0);                // ones where a < b
    ops++;
    append_and(0, 0, 2);                        // b where a >= b
    ops++;
    append_and(workRegister, workRegister, 1);  // a where a < b
    ops++;
    append_xor(0, 0, workRegister);             // register 0 = min(a, b)
    ops++;
  }

  // Debug helper: emits append_print for registers 0-2.  Not called by the
  // instruction generators.
  void print() {
    rep(i,0,3) append_print(i);
  }

  // Emits the n = 2 minimum program; the answer is left in register 0.
  void findMin() {
    spread();
    while (n > 1) {  // exactly one iteration for n == 2
      combine();
      n = (n+1) / 2;
    }
  }

  // Builds the constant registers, emits the program and checks the
  // instruction budget.  Only s == 0 is supported.
  void construct_instructions(int _s, int _n, int _k, int _q) {
    s = _s;
    n = _n;
    k = _k;
    q = _q;
    // negOne: in every complete (k+2)-bit slot below bit 2000, bits [0, k].
    for (int a = 0, b = k+2; b < 2000; a = b, b += k+2)
      for (int j = a; j < b; j++) negOne[j] = j != b-1;
    for (int i = 0; i < 2000; i++) {
      signBit[i] = i % (k+2) == k+1;
      oneNum[i] = i < k;
    }
    append_store(negOneIndex, negOne);
    ops++;
    append_store(signBitIndex, signBit);
    ops++;
    append_store(oneNumIndex, oneNum);
    ops++;

    if (s == 0) findMin();
    else assert(false);
    assert(ops <= q);
  }

}

// Grader entry point: selects the instruction generator for the task.
//   s == 1           -> SOLUTION_S1 (sorting)
//   s != 1, n == 2   -> SOLUTION_N2 (single comparison)
//   s != 1, n != 2   -> SOLUTION_S0 (halving tournament for the minimum)
void construct_instructions(int s, int n, int k, int q) {
  if (s != 1) {
    if (n == 2) {
      SOLUTION_N2::construct_instructions(s, n, k, q);
      return;
    }
    SOLUTION_S0::construct_instructions(s, n, k, q);
    return;
  }
  SOLUTION_S1::construct_instructions(s, n, k, q);
}
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| 1 | Correct | 1 ms | 204 KB | Output is correct |
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| 1 | Correct | 1 ms | 204 KB | Output is correct |
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| 1 | Correct | 1 ms | 204 KB | Output is correct |
| 2 | Correct | 1 ms | 204 KB | Output is correct |
| 3 | Correct | 1 ms | 332 KB | Output is correct |
| 4 | Correct | 1 ms | 204 KB | Output is correct |
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| 1 | Correct | 1 ms | 332 KB | Output is correct |
| 2 | Correct | 1 ms | 204 KB | Output is correct |
| 3 | Correct | 1 ms | 204 KB | Output is correct |
| 4 | Correct | 1 ms | 332 KB | Output is correct |
| 5 | Correct | 1 ms | 204 KB | Output is correct |
| 6 | Correct | 1 ms | 204 KB | Output is correct |
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| 1 | Correct | 2 ms | 332 KB | Output is correct |
| 2 | Correct | 1 ms | 332 KB | Output is correct |
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| 1 | Correct | 2 ms | 332 KB | Output is correct |
| 2 | Correct | 1 ms | 332 KB | Output is correct |
| 3 | Correct | 6 ms | 892 KB | Output is correct |
| 4 | Correct | 5 ms | 768 KB | Output is correct |
| 5 | Correct | 4 ms | 768 KB | Output is correct |
| 6 | Correct | 4 ms | 768 KB | Output is correct |
| 7 | Correct | 6 ms | 640 KB | Output is correct |