답안 #769968

# 제출 시각 아이디 문제 언어 결과 실행 시간 메모리
769968 2023-06-30T15:23:15 Z marvinthang 앨리스, 밥, 서킷 (APIO23_abc) C++17
100 / 100
494 ms 245020 KB
#include "abc.h"
#include <bits/stdc++.h>
 
using namespace std;
 
// Shorthand macros.
// NOTE: `fi` / `se` are rewritten to `first` / `second` everywhere below, so a
// member declared as `fi` (e.g. in Element) is really named `first`.
#define                  fi  first
#define                  se  second
// Renames the identifiers `left` / `right` (presumably to dodge std::left /
// std::right under `using namespace std`).
#define                left  ___left
#define               right  ___right
// CPU time in seconds since program start, via clock().
#define                TIME  (1.0 * clock() / CLOCKS_PER_SEC)
// 64-bit value with only bit i set.
#define             MASK(i)  (1LL << (i))
// Bit i of x (0 or 1).
#define           BIT(x, i)  ((x) >> (i) & 1)
// Routes __builtin_popcount to the 64-bit builtin.
#define  __builtin_popcount  __builtin_popcountll
// Begin/end iterator pair of a container, for algorithm calls.
#define              ALL(v)  (v).begin(), (v).end()
// i over [0, n), n evaluated once.
#define           REP(i, n)  for (int i = 0, _n = (n); i < _n; ++i)
// i from n - 1 down to 0.
#define          REPD(i, n)  for (int i = (n); i--; )
// i over [a, b).
#define        FOR(i, a, b)  for (int i = (a), _b = (b); i < _b; ++i) 
// i from b - 1 down to a.
#define       FORD(i, b, a)  for (int i = (b), _a = (a); --i >= _a; ) 
// i over [a, b] (inclusive).
#define       FORE(i, a, b)  for (int i = (a), _b = (b); i <= _b; ++i) 
// i from b down to a (inclusive).
#define      FORDE(i, b, a)  for (int i = (b), _a = (a); i >= _a; --i) 
// Declarators for stream operators; the operand is always named `u`.
#define        scan_op(...)  istream & operator >> (istream &in, __VA_ARGS__ &u)
#define       print_op(...)  ostream & operator << (ostream &out, const __VA_ARGS__ &u)
#ifdef LOCAL
    #include "debug.h"
#else
    // file(name): redirect stdin/stdout to name.inp / name.out when name.inp exists.
    #define file(name) if (fopen(name".inp", "r")) { freopen(name".inp", "r", stdin); freopen(name".out", "w", stdout); }
    // Debug hooks expand to the constant 23, so they compile to nothing useful.
    #define DB(...) 23
    #define db(...) 23
    #define debug(...) 23
#endif
 
// Reads a pair as two whitespace-separated values.
template <class U, class V> scan_op(pair <U, V>)  { return in >> u.first >> u.second; }
// Reads every element of an already-sized vector (the size itself is not read).
template <class T> scan_op(vector <T>)  { for (size_t i = 0; i < u.size(); ++i) in >> u[i]; return in; }
// Prints a pair as "(a, b)".
template <class U, class V> print_op(pair <U, V>)  { return out << '(' << u.first << ", " << u.second << ')'; }
// Recursively prints tuple elements i.. as "(a, b, ...)".
// NOTE(review): for an empty tuple this prints only ")".
template <size_t i, class T> ostream & print_tuple_utils(ostream &out, const T &tup) { if constexpr(i == tuple_size<T>::value) return out << ")";  else return print_tuple_utils<i + 1, T>(out << (i ? ", " : "(") << get<i>(tup), tup); }
// Prints any std::tuple via print_tuple_utils.
template <class ...U> print_op(tuple<U...>) { return print_tuple_utils<0, tuple<U...>>(out, u); }
// Prints any iterable except std::string as "{a, b, ...}".
template <class Con, class = decltype(begin(declval<Con>()))> typename enable_if <!is_same<Con, string>::value, ostream&>::type operator << (ostream &out, const Con &con) { out << '{'; for (__typeof(con.begin()) it = con.begin(); it != con.end(); ++it) out << (it == con.begin() ? "" : ", ") << *it; return out << '}'; }
// Raises `a` to `b` when `b` is strictly larger; returns whether `a` changed.
template <class U, class V>
bool maximize(U &a, V b) {
    if (!(a < b)) return false;
    a = b;
    return true;
}

// Shorthand for a pair of ints (used for the (x, y) slot pairs of the networks).
using ii = pair <int, int>;

// end of template

// Gate opcodes. For inputs (x0, x1) a gate outputs bit (x0 + 2 * x1) of its
// opcode, e.g. OP_AND = 0b1000 is 1 only when x0 = x1 = 1.
constexpr int OP_ZERO    = 0;  // f(OP_ZERO,    x0, x1) = 0
constexpr int OP_NOR     = 1;  // f(OP_NOR,     x0, x1) = !(x0 || x1)
constexpr int OP_GREATER = 2;  // f(OP_GREATER, x0, x1) = (x0 > x1)
constexpr int OP_NOT_X1  = 3;  // f(OP_NOT_X1,  x0, x1) = !x1
constexpr int OP_LESS    = 4;  // f(OP_LESS,    x0, x1) = (x0 < x1)
constexpr int OP_NOT_X0  = 5;  // f(OP_NOT_X0,  x0, x1) = !x0
constexpr int OP_XOR     = 6;  // f(OP_XOR,     x0, x1) = (x0 ^ x1)
constexpr int OP_NAND    = 7;  // f(OP_NAND,    x0, x1) = !(x0 && x1)
constexpr int OP_AND     = 8;  // f(OP_AND,     x0, x1) = (x0 && x1)
constexpr int OP_EQUAL   = 9;  // f(OP_EQUAL,   x0, x1) = (x0 == x1)
constexpr int OP_X0      = 10; // f(OP_X0,      x0, x1) = x0
constexpr int OP_GEQ     = 11; // f(OP_GEQ,     x0, x1) = (x0 >= x1)
constexpr int OP_X1      = 12; // f(OP_X1,      x0, x1) = x1
constexpr int OP_LEQ     = 13; // f(OP_LEQ,     x0, x1) = (x0 <= x1)
constexpr int OP_OR      = 14; // f(OP_OR,      x0, x1) = (x0 || x1)
constexpr int OP_ONE     = 15; // f(OP_ONE,     x0, x1) = 1
// Presumably the 16-bit word width; not referenced by the code shown.
constexpr int LOG = 16;

// Maps a name of up to 4 letters to a dense integer code. The name is read as a
// bijective base-26 numeral ('a' = 1 ... 'z' = 26; assumes lowercase input) and
// shifted down by one: "a" -> 0, "z" -> 25, "aa" -> 26, ..., "zzzz" -> 475253,
// which fits in the 19 bits alice()/bob() send per code. Reading stops at the
// first NUL or after 4 characters; an empty name yields -1.
int encode(const char name[5]) {
    int code = 0;
    for (int pos = 0; pos < 4 && name[pos] != '\0'; ++pos) {
        const int digit = name[pos] - 'a' + 1;
        code = code * 26 + digit;
    }
    return code - 1;
}

// Computes the control bits for the recursive swap network that
// buildSwapNetwork(0, n, ...) emits. That network uses the same split: with
// m = n / 2, a first layer pairs i with i + m, then it recurses on [0, m) and
// [m, n), then a last layer pairs i with i + m again.
//
// `a` must be a permutation of 0..n-1 (alice() and bob() pass one). Suppose the
// network's pairs are applied in order, swapping pair i exactly when returned
// bit i is 1. Then the entry at position p ends up at position a[p]. The first
// layer sends the two members of every value pair {x, x + m} to different
// halves. The recursive calls place each half. The last layer moves the values
// >= m that sit in the left half back to the right half.
//
// Returned layout: m first-layer bits, then the bits for the left half
// (size m), then for the right half (size n - m), then m last-layer bits.
// Empty for n <= 1.
vector <int> traceSwap(vector <int> a) {
    int n = a.size();
    if (n <= 1) return vector<int>();
    // Inverse permutation: pos[value] = position of value in a.
    vector <int> pos(n);
    REP(i, n) pos[a[i]] = i;
    // swap_first[u] is the first-layer bit of column u (positions u and u + m);
    // -1 means "not decided yet". For odd n, extra column m holds only the last
    // position 2m; it is popped again before the first layer is applied.
    vector <int> swap_first((n + 1) / 2, -1);
    int m = n / 2;
    // From an already-decided column u, repeatedly follow one of its values to
    // the column holding that value's partner (x <-> x +/- m). Decide that
    // column so the two partners end up in different halves, and continue from
    // it. Stops when no undecided neighbouring column is found.
    auto dfs = [&] (int u) {
        while (~u) {
            int nxt = -1;
            for (int f: {0, m}) {
                // The extra column m only has its right slot (position 2m).
                if (u == m && !f) continue;
                int x = a[u + f];
                if (x == m + m) {
                    // Value 2m (odd n only) has no partner and must finish in
                    // the right half; the decision for this column must agree.
                    assert(swap_first[u] == !f);
                    continue;
                }
                // v: column of the partner value. w: parity that puts the
                // partner in the opposite half from x.
                int v = pos[x < m ? x + m : x - m];
                int w = !f;
                if (v >= m) {
                    v -= m;
                    w ^= 1;
                }
                if (swap_first[v] == -1) {
                    swap_first[v] = swap_first[u] ^ w;
                    nxt = v;
                    break;
                }
            }
            u = nxt;
        }
    };
    if (n & 1) {
        swap_first[m] = 0;
        dfs(m);
        swap_first.pop_back();
    }
    // Remaining components are unconstrained: start each with bit 0.
    REP(i, m) if (!~swap_first[i]) {
        swap_first[i] = 0;
        dfs(i);
    }
    // Apply the first layer. Then reduce each half to a smaller permutation:
    // a left-half value j + m becomes j and is recorded in swap_last[j], so the
    // last layer moves it back to position j + m.
    REP(i, m) if (swap_first[i]) swap(a[i], a[i + m]);
    vector <int> swap_last(m);
    REP(i, m) if (a[i] >= m) {
        a[i] -= m;
        swap_last[a[i]] = 1;
    }
    // Right half: values >= m shift down by m (2m becomes m for odd n).
    transform(m + ALL(a), a.begin() + m, [&] (int x) { return x >= m ? x - m : x; });
    vector <int> swap_left = traceSwap(vector<int>(a.begin(), a.begin() + m));
    vector <int> swap_right = traceSwap(vector<int>(m + ALL(a)));
    swap_first.insert(swap_first.end(), ALL(swap_left));
    swap_first.insert(swap_first.end(), ALL(swap_right));
    swap_first.insert(swap_first.end(), ALL(swap_last));
    return swap_first;
}

// Alice
int // returns la
alice(
    /*  in */ const int n,
    /*  in */ const char names[][5],
    /*  in */ const unsigned short numbers[],
    /* out */ bool outputs_alice[]
) {
    int l = 0;
    auto send = [&] (int value, int num_bit) {
        REP(i, num_bit) outputs_alice[l++] = BIT(value, i);
    };
    vector <pair <int, int>> order;
    REP(i, n) order.emplace_back(encode(names[i]), i);
    sort(ALL(order));

    vector <int> perm;
    for (auto [v, i]: order) {
        send(v, 19);
        send(numbers[i], 16);
        perm.push_back(i);
    }
    for (int x: traceSwap(perm)) send(x, 1);
    return l;
}


// Bob
int // returns lb
bob(
    /*  in */ const int m,
    /*  in */ const char senders[][5],
    /*  in */ const char recipients[][5],
    /* out */ bool outputs_bob[]
) {
    int l = 0;
    auto send = [&] (int value, int num_bit) {
        REP(i, num_bit) outputs_bob[l++] = BIT(value, i);
    };
    vector <tuple <int, int, int>> order;
    REP(i, m) order.emplace_back(encode(recipients[i]), encode(senders[i]), -1);
    sort(ALL(order));
    REP(i, m) {
        swap(get<0>(order[i]), get<1>(order[i]));
        get<2>(order[i]) = i;
    }
    sort(ALL(order));
    vector <int> perm;
    for (auto [s, r, i]: order) {
        send(s, 19);
        send(r, 19);
        perm.push_back(i);
    }
    for (int x: traceSwap(perm)) send(x, 1);
    return l;
}

// Appends to `res` the switch pairs of a recursive network on positions [l, r).
// With half = (r - l) / 2 and mid = l + half, the network is:
//  - a first layer pairing i with i + half for i in [l, mid);
//  - the networks for [l, mid) and [mid, r);
//  - the same layer again.
// traceSwap() produces one control bit per pair, in this order.
void buildSwapNetwork(int l, int r, vector <ii> &res) {
    const int len = r - l;
    if (len < 2) return;
    const int half = len / 2;
    const int mid = l + half;
    for (int i = l; i < mid; ++i) res.emplace_back(i, i + half);
    buildSwapNetwork(l, mid, res);
    buildSwapNetwork(mid, r, res);
    for (int i = l; i < mid; ++i) res.emplace_back(i, i + half);
}

// Appends to `res` the comparators of an odd-even merge over positions
// l, l + step, l + 2*step, ... < r.
//  - If only two such positions remain, a single comparator (l, l + step).
//  - Otherwise: merge the even-indexed subsequence, merge the odd-indexed one,
//    then compare (i, i + step) for i = l + step, l + 3*step, ...
// Called as (0, 2k, 1) with k a power of two, this merges two sorted halves of
// length k.
void buildMergeNetwork(int l, int r, int step, vector <ii> &res) {
    const int stride = step * 2;
    if (l + stride < r) {
        buildMergeNetwork(l, r, stride, res);
        buildMergeNetwork(l + step, r, stride, res);
        for (int i = l + step; i + step < r; i += stride) res.emplace_back(i, i + step);
    } else {
        res.emplace_back(l, l + step);
    }
}

// Memo table for cntSwapNetwork; 0 means "not computed yet" (every n >= 2 has
// a positive count, so 0 is never a real cached value).
int cache[1002];

// Number of switch pairs buildSwapNetwork emits for n positions. This is also
// the number of control bits traceSwap returns: C(n) = m + C(m) + C(n - m) + m
// with m = n / 2, and C(n) = 0 for n <= 1.
// Only n inside the memo table is cached. Larger (or negative) n are computed
// directly instead of indexing past the end of `cache`.
int cntSwapNetwork(int n) {
    if (n <= 1) return 0;
    const bool cacheable = n < static_cast<int>(sizeof(cache) / sizeof(cache[0]));
    if (cacheable && cache[n]) return cache[n];
    const int m = n / 2;
    const int total = m + cntSwapNetwork(m) + cntSwapNetwork(n - m) + m;
    if (cacheable) cache[n] = total;
    return total;
}

// A record carried through the comparator and switch networks in circuit().
// Every field stores gate indices, not values. The file's `fi` / `se` macros
// expand to `first` / `second`, so those are the real member names.
struct Element {
    // Current sort key: 19 gate indices, index 0 = least-significant bit.
    vector <int> fi;
    // Alternate key, exchanged with fi between the two merge passes.
    vector <int> se;
    // 16-bit payload, least-significant bit first.
    vector <int> val;
    // Gate giving the record kind (circuit() stores its `zero` gate for Alice,
    // `one` for Bob).
    int type;
};

// Circuit
int // returns l
circuit(
    /*  in */ const int la,
    /*  in */ const int lb,
    /* out */ int operations[],
    /* out */ int operands[][2],
    /* out */ int out[][16]
) {
    // Overview. Data-dependent work is done with gates. Routing decisions that
    // depend only on n and m (empty slots) are made while building.
    //  1. Recover n (Alice's names) and m (Bob's messages) from la / lb.
    //  2. Load Alice's records into slots [0, n): key = name, payload = number,
    //     type = 0. Load Bob's into [k, k + m): key = sender, alternate key =
    //     recipient, type = 1. Each range arrives sorted by key.
    //  3. Merge by (key, type), so each Alice record directly precedes the
    //     messages she sent. A forward scan copies her number into them.
    //  4. Undo the merge. Route Bob's messages into recipient order with his
    //     control bits, then merge again. A backward scan gives each Alice
    //     record the 16-bit sum of the payloads of the messages addressed to
    //     her.
    //  5. Undo that merge (payload only) and route the sums back to Alice's
    //     input order with her control bits.
    // Assumes every sender and recipient is one of Alice's names.
    int num_gate = la + lb;

    // Appends gate op(a, b) and returns its index.
    auto get_gate = [&] (int op, int a, int b) {
        operations[num_gate] = op;
        operands[num_gate][0] = a;
        operands[num_gate][1] = b;
        return num_gate++;
    };

    int zero = get_gate(OP_ZERO, 0, 0);
    int one = get_gate(OP_ONE, 0, 0);

    // Conditional swap: when gate p is 1, exchange wires x and y (4 gates).
    auto _swap = [&] (int &x, int &y, int p) {
        int a = get_gate(OP_AND, p, get_gate(OP_XOR, x, y));
        x = get_gate(OP_XOR, x, a);
        y = get_gate(OP_XOR, y, a);
    };

    // Bitwise op between two 16-wire words.
    auto get_gate_2 = [&] (int op, const vector <int> &a, const vector <int> &b) {
        vector <int> res;
        REP(i, 16) res.push_back(get_gate(op, a[i], b[i]));
        return res;
    };

    // Bitwise op between a 16-wire word and one wire b.
    auto get_gate_3 = [&] (int op, const vector <int> &a, int b) {
        vector <int> res;
        REP(i, 16) res.push_back(get_gate(op, a[i], b));
        return res;
    };

    // 16-bit ripple-carry adder. The final carry is dropped, so the result is
    // (a + b) mod 2^16.
    auto sum = [&] (const vector <int> &a, const vector <int> &b) {
        int carry = zero;
        vector <int> res;
        REP(i, 16) {
            int half = get_gate(OP_XOR, a[i], b[i]);
            int both = get_gate(OP_AND, a[i], b[i]);
            res.push_back(get_gate(OP_XOR, half, carry));
            // Reuse `both` rather than rebuilding AND(a[i], b[i]).
            carry = get_gate(OP_OR, both, get_gate(OP_AND, half, carry));
        }
        return res;
    };

    // Alice sends 35 bits per name and Bob 38 per message, each followed by
    // cntSwapNetwork(count) control bits. Invert those lengths.
    int n = 0;
    while (35 * n + cntSwapNetwork(n) < la) ++n;
    int m = 0;
    while (38 * m + cntSwapNetwork(m) < lb) ++m;

    // Two halves of size k (a power of two), as the merge network requires.
    int k = 1;
    while (k < max(m, n)) k <<= 1;
    vector <Element> elements(k + k);

    REP(i, n) {
        REP(j, 19) {
            elements[i].fi.push_back(35 * i + j);
            elements[i].se.push_back(35 * i + j);
        }
        REP(j, 16) elements[i].val.push_back(35 * i + 19 + j);
        elements[i].type = zero;
    }

    REP(i, m) {
        REP(j, 19) {
            elements[i + k].fi.push_back(la + 38 * i + j);
            elements[i + k].se.push_back(la + 38 * i + 19 + j);
        }
        elements[i + k].val.assign(16, zero);
        elements[i + k].type = one;
    }

    // cmp[i] is the decision for comparator/switch i of the current network:
    //   -1  undecided, computed from the keys on first use;
    //   -2  always swap;  -3  never swap (decided while building);
    //   >=0 gate that is 1 when the pair must be swapped.
    vector <int> cmp;
    // Applies `network` to `elements`. With just_val, only payload wires are
    // swapped by gates (keys and types are no longer needed).
    auto shuffle = [&] (const vector <ii> &network, bool just_val = false) {
        REP(i, network.size()) {
            auto [x, y] = network[i];
            if (cmp[i] == -1) {
                if (elements[x].fi.empty() || elements[y].fi.empty()) {
                    // Empty slots sort after every real record.
                    cmp[i] = !elements[y].fi.empty() ? -2 : -3;
                }
                else {
                    // Swap iff (fi[x], type[x]) > (fi[y], type[y]), comparing
                    // key bits from the most significant (j = 18) down.
                    int equal = one;
                    cmp[i] = zero;
                    REPD(j, 19) {
                        int a = elements[x].fi[j], b = elements[y].fi[j];
                        assert(equal != -1);
                        cmp[i] = get_gate(OP_OR, cmp[i], get_gate(OP_AND, equal, get_gate(OP_GREATER, a, b)));
                        equal = get_gate(OP_AND, equal, get_gate(OP_EQUAL, a, b));
                    }
                    cmp[i] = get_gate(OP_OR, cmp[i], get_gate(OP_AND, equal, get_gate(OP_GREATER, elements[x].type, elements[y].type)));
                }
            }
            if (cmp[i] < 0) {
                if (cmp[i] == -2) swap(elements[x], elements[y]);
            } else {
                if (!just_val) {
                    REP(j, 19) {
                        _swap(elements[x].fi[j], elements[y].fi[j], cmp[i]);
                        _swap(elements[x].se[j], elements[y].se[j], cmp[i]);
                    }
                    _swap(elements[x].type, elements[y].type, cmp[i]);
                }
                REP(j, 16) _swap(elements[x].val[j], elements[y].val[j], cmp[i]);
            }
        }
    };

    // Pass 1: merge by (name / sender, type).
    vector <ii> merge_network, swap_network;
    buildMergeNetwork(0, k + k, 1, merge_network);
    cmp.assign(merge_network.size(), -1);
    shuffle(merge_network);

    // Forward scan. An Alice record (type 0) loads its number into `cur`; a Bob
    // record (type 1) receives `cur`, i.e. its sender's number. Every record
    // then switches to its alternate key (the recipient, for Bob).
    vector <int> cur(16, zero);
    REP(i, n + m) {
        elements[i].val = cur = get_gate_2(OP_XOR, cur, get_gate_3(OP_GREATER, get_gate_2(OP_XOR, cur, elements[i].val), elements[i].type));
        swap(elements[i].fi, elements[i].se);
    }

    // Undo pass 1 by replaying its comparators backwards with the same decisions.
    reverse(ALL(merge_network));
    reverse(ALL(cmp));
    shuffle(merge_network);
    reverse(ALL(merge_network));
    cmp.clear();

    // Route Bob's messages from sender order to recipient order, using the
    // control bits Bob appended after his 38 * m name bits.
    buildSwapNetwork(k, k + m, swap_network);
    REP(i, swap_network.size()) cmp.push_back(la + 38 * m + i);
    shuffle(swap_network);

    // Pass 2: merge by (name / recipient, type).
    cmp.assign(merge_network.size(), -1);
    shuffle(merge_network);

    // Backward scan. Each record's val becomes the sum of the payloads of the
    // Bob records after it, up to the next Alice record; an Alice record
    // (type 0) resets the running sum to zero.
    cur.assign(16, zero);
    REPD(i, n + m) {
        swap(cur, elements[i].val);
        cur = get_gate_3(OP_AND, sum(cur, elements[i].val), elements[i].type);
    }

    // Undo pass 2 for the payload only.
    reverse(ALL(merge_network));
    reverse(ALL(cmp));
    shuffle(merge_network, true);

    // Route Alice's results from name order back to her input order, using
    // the control bits she appended after her 35 * n data bits.
    swap_network.clear();
    buildSwapNetwork(0, n, swap_network);
    cmp.clear();
    REP(i, swap_network.size()) cmp.push_back(35 * n + i);
    shuffle(swap_network, true);

    REP(i, n) REP(j, 16) out[i][j] = elements[i].val[j];
    return num_gate;
}

Compilation message

abc.cpp: In function 'void buildSwapNetwork(int, int, std::vector<std::pair<int, int> >&)':
abc.cpp:187:15: warning: suggest parentheses around '+' inside '>>' [-Wparentheses]
  187 |     int m = l + r >> 1;
      |             ~~^~~
abc.cpp: In lambda function:
abc.cpp:265:17: warning: unused variable 'new_r' [-Wunused-variable]
  265 |             int new_r = get_gate(OP_AND, a[i], b[i]);
      |                 ^~~~~
# 결과 실행 시간 메모리 Grader output
1 Correct 2 ms 1196 KB Correct!
# 결과 실행 시간 메모리 Grader output
1 Correct 2 ms 1196 KB Correct!
2 Correct 2 ms 1272 KB Correct!
# 결과 실행 시간 메모리 Grader output
1 Correct 2 ms 1196 KB Correct!
2 Correct 2 ms 1272 KB Correct!
3 Correct 238 ms 137840 KB Correct!
4 Correct 237 ms 138180 KB Correct!
# 결과 실행 시간 메모리 Grader output
1 Correct 8 ms 4184 KB Correct!
2 Correct 128 ms 69540 KB Correct!
3 Correct 183 ms 93816 KB Correct!
# 결과 실행 시간 메모리 Grader output
1 Correct 8 ms 4184 KB Correct!
2 Correct 128 ms 69540 KB Correct!
3 Correct 183 ms 93816 KB Correct!
4 Correct 138 ms 69632 KB Correct!
5 Correct 184 ms 93684 KB Correct!
# 결과 실행 시간 메모리 Grader output
1 Correct 8 ms 4184 KB Correct!
2 Correct 128 ms 69540 KB Correct!
3 Correct 183 ms 93816 KB Correct!
4 Correct 138 ms 69632 KB Correct!
5 Correct 184 ms 93684 KB Correct!
6 Correct 120 ms 59404 KB Correct!
7 Correct 242 ms 128712 KB Correct!
# 결과 실행 시간 메모리 Grader output
1 Correct 466 ms 240988 KB Correct!
2 Correct 491 ms 241656 KB Correct!
# 결과 실행 시간 메모리 Grader output
1 Correct 466 ms 240988 KB Correct!
2 Correct 491 ms 241656 KB Correct!
3 Correct 452 ms 221944 KB Correct!
4 Correct 480 ms 241492 KB Correct!
# 결과 실행 시간 메모리 Grader output
1 Correct 2 ms 1196 KB Correct!
2 Correct 2 ms 1272 KB Correct!
3 Correct 238 ms 137840 KB Correct!
4 Correct 237 ms 138180 KB Correct!
5 Correct 8 ms 4184 KB Correct!
6 Correct 128 ms 69540 KB Correct!
7 Correct 183 ms 93816 KB Correct!
8 Correct 138 ms 69632 KB Correct!
9 Correct 184 ms 93684 KB Correct!
10 Correct 120 ms 59404 KB Correct!
11 Correct 242 ms 128712 KB Correct!
12 Correct 466 ms 240988 KB Correct!
13 Correct 491 ms 241656 KB Correct!
14 Correct 452 ms 221944 KB Correct!
15 Correct 480 ms 241492 KB Correct!
16 Correct 490 ms 244612 KB Correct!
17 Correct 494 ms 245020 KB Correct!