답안 #1075009

# 제출 시각 아이디 문제 언어 결과 실행 시간 메모리
1075009 2024-08-25T17:32:38 Z GusterGoose27 Ancient Machine (JOI21_ancient_machine) C++17
100 / 100
66 ms 10448 KB
#include "Anna.h"

#include <bits/stdc++.h>
using namespace std;

#define sz(s) (int(s.size()))

typedef unsigned int ui;
typedef unsigned long long ll;
typedef long double ld;

namespace anna {

// 64-bit unsigned type used for all interval arithmetic (same type as the
// file-level `ll` typedef).
using u64 = unsigned long long;

// Width, in bits, of the integer interval the arithmetic coder works on.
const int enc_bits = 32;

// Symbol frequencies for symbols 0..3, scaled so that they sum to exactly
// 2^enc_bits (the encoder constructor asserts this). Symbol 3 is used by the
// callers as the end-of-message marker and gets the tiny frequency 5.
std::vector<u64> optimal_probs({1640535876ULL, 1640535876ULL, 1013895539ULL, 5ULL});

// Integer arithmetic coder with a static frequency table and an explicit
// end-of-message symbol.
//
// The coder keeps a closed interval [low, high] inside [0, 2^enc_bits) and
// renormalises it by shifting whenever the interval lies completely in the
// lower half, completely in the upper half, or straddles the midpoint inside
// the middle two quarters (in which case the output bit is deferred via a
// pending counter).
class encoder {
public:
    std::vector<u64> probs;  // probs[i]: frequency of symbol i
    std::vector<u64> psum;   // psum[i]: sum of probs[0..i), psum.back() == 2^enc_bits
    int ef;                  // symbol that terminates a message

    // Builds the cumulative table from `p`; the frequencies must sum to
    // 2^enc_bits (checked with assert).
    encoder(const std::vector<u64> &p, int ef) : ef(ef) {
        psum.push_back(0);
        // Iterate as u64: iterating as int would truncate frequencies >= 2^31.
        for (u64 v : p) {
            probs.push_back(v);
            psum.push_back(psum.back() + v);
        }
        assert(psum.back() == (u64{1} << enc_bits));
    }

    // Doubles `v` and drops the bit shifted out of the enc_bits-wide window.
    u64 shift(u64 v) {
        return (v << 1) & ((u64{1} << enc_bits) - 1);
    }

    // Middle-range expansion: equals 2 * (v - 2^(enc_bits-2)) for any v in the
    // middle two quarters of the window.
    u64 shift2(u64 v) {
        const u64 top = u64{1} << (enc_bits - 1);
        return ((v << 1) & (top - 1)) + (v & top);
    }

    // Appends `ef` to `msg` (the caller's vector is modified) and writes the
    // coded bits, one 0/1 int per bit, to the end of `out`. Encoding stops at
    // the first occurrence of `ef` in `msg`.
    void encode(std::vector<int> &msg, std::vector<int> &out) {
        msg.push_back(ef);
        u64 low = 0;
        u64 high = (u64{1} << enc_bits) - 1;
        int pend = 0;  // deferred bits from middle-range expansions
        const u64 div = u64{1} << (enc_bits - 1);
        // Emits `b` followed by all deferred bits, which are its complement.
        auto ppend = [&](bool b) {
            out.push_back(b);
            for (; pend > 0; --pend) out.push_back(!b);
        };
        for (int v : msg) {
            assert(v >= 0 && v < static_cast<int>(probs.size()));
            const u64 range = high - low + 1;
            high = low + scale(range, psum[v + 1]) - 1;
            low = low + scale(range, psum[v]);
            while (true) {
                if (high < div) {
                    ppend(0);
                    high = shift(high);
                    low = shift(low);
                } else if (low >= div) {
                    ppend(1);
                    high = shift(high);
                    low = shift(low);
                } else if (low >= div / 2 && high < 3 * div / 2) {
                    pend++;
                    low = shift2(low);
                    high = shift2(high);
                } else {
                    break;
                }
            }
            if (v == ef) break;
        }
        // Two closing bits (plus deferred ones) select a value that is inside
        // the final interval once the decoder pads the stream with zeros.
        pend++;
        ppend(low >= div / 2);
    }

    // Returns the symbol whose scaled cumulative slot contains offset `v`
    // inside an interval of width `r`. Requires v < r.
    int get_char(u64 r, u64 v) {
        assert(v < r);
        const int count = static_cast<int>(probs.size());
        for (int i = 0; i < count; i++) {
            if (scale(r, psum[i + 1]) > v) return i;
        }
        // Unreachable when psum.back() == 2^enc_bits, because the last slot
        // then ends at r > v. Kept as a guard for builds without asserts.
        assert(false);
        return count - 1;
    }

    // Decodes the bit vector produced by encode(). `inp` is padded in place
    // with enc_bits zero bits; decoded symbols (without the terminator) are
    // appended to `out`. At most inp.size() symbols (after padding) are
    // decoded.
    void decode(std::vector<int> &inp, std::vector<int> &out) {
        for (int i = 0; i < enc_bits; i++) inp.push_back(0);
        const int total = static_cast<int>(inp.size());
        int p = 0;
        // Bits past the end of a malformed stream read as zero instead of
        // indexing out of bounds.
        auto next_bit = [&]() -> u64 {
            return p < total ? static_cast<u64>(inp[p++]) : u64{0};
        };
        u64 value = 0;
        for (int i = 0; i < enc_bits; i++) value = (value << 1) + next_bit();
        u64 low = 0;
        u64 high = (u64{1} << enc_bits) - 1;
        const u64 div = u64{1} << (enc_bits - 1);
        for (int i = 0; i < total; i++) {
            const u64 range = high - low + 1;
            const int v = get_char(range, value - low);
            if (v != ef) out.push_back(v);
            high = low + scale(range, psum[v + 1]) - 1;
            low = low + scale(range, psum[v]);
            // Mirror the encoder's renormalisation, pulling in one new bit
            // per shift.
            while (true) {
                if (high < div || low >= div) {
                    high = shift(high);
                    low = shift(low);
                    value = shift(value);
                } else if (low >= div / 2 && high < 3 * div / 2) {
                    high = shift2(high);
                    low = shift2(low);
                    value = shift2(value);
                } else {
                    break;
                }
                value += next_bit();
            }
            if (v == ef) break;
        }
    }

private:
    // Computes (range * cum) >> enc_bits without overflowing 64 bits when
    // both operands equal 2^enc_bits (full initial interval, last symbol).
    static u64 scale(u64 range, u64 cum) {
        if (cum == (u64{1} << enc_bits)) return range;
        return (range * cum) >> enc_bits;
    }
};

// Debug helper: writes the elements of `v` to stderr, space separated.
void print(const std::vector<int> &v) {
    for (int u : v) std::cerr << u << ' ';
    std::cerr << '\n';
}

}

void Anna(int n, vector<char> S) {
    vector<int> conv;
    int p = 1;
    for (int i = 0; i < n && S[i] != 'X'; i++) S[i] = 'Y';
    int r = n-1;
    for (; r >= 0 && S[r] != 'Z'; r--) S[r] = 'Y';
    if (r < 2) return;
    for (int i = r-1; i > 0 && S[i] != 'Y'; i--) S[i] = 'X';
    for (char c: S) {
        int v = c-'X';
        if ((v&1)==(p&1)) conv.push_back(1);
        else {
            conv.push_back(v);
            p = v;
        }
    }
    conv.resize(r);
    conv.push_back(1);
    vector<int> gold_code; // 0 -> XY, 1 -> ZY, 2 -> YXY
    for (int i = 0; i < sz(conv);) {
        if (conv[i] != 1) {
            assert(i < sz(conv)-1);
            assert(conv[i+1] == 1);
            gold_code.push_back(conv[i]/2);
            i += 2;
        }
        else {
            if (i == sz(conv)-1) break;
            if (conv[i+1] == 1) {
                conv[i] = 2;
                gold_code.push_back(1);
                i += 2;
            }
            else if (conv[i+1] == 2) {
                swap(conv[i], conv[i+1]);
                gold_code.push_back(1);
                i += 2;
            }
            else {
                assert(i < sz(conv)-2);
                assert(conv[i+2] == 1);
                gold_code.push_back(2);
                i += 3;
            }
        }
    }
    // anna::print(gold_code);
    bool ex = (conv.end()[-2] == 1);
    if (ex) conv.pop_back();
    vector<int> to_send;
    anna::encoder enc(anna::optimal_probs, 3);
    enc.encode(gold_code, to_send);
    for (int v: to_send) Send(v);
    Send(ex);
}
#include "Bruno.h"
#include <bits/stdc++.h>

using namespace std;
typedef unsigned long long ll;

#define sz(s) (int(s.size()))

namespace bruno {

// 64-bit unsigned type used for all interval arithmetic (same type as the
// file-level `ll` typedef).
using u64 = unsigned long long;

// Width, in bits, of the integer interval the arithmetic coder works on.
const int enc_bits = 32;

// Symbol frequencies for symbols 0..3, scaled so that they sum to exactly
// 2^enc_bits (the encoder constructor asserts this). Symbol 3 is used by the
// callers as the end-of-message marker and gets the tiny frequency 5.
std::vector<u64> optimal_probs({1640535876ULL, 1640535876ULL, 1013895539ULL, 5ULL});

// Integer arithmetic coder with a static frequency table and an explicit
// end-of-message symbol.
//
// The coder keeps a closed interval [low, high] inside [0, 2^enc_bits) and
// renormalises it by shifting whenever the interval lies completely in the
// lower half, completely in the upper half, or straddles the midpoint inside
// the middle two quarters (in which case the output bit is deferred via a
// pending counter).
class encoder {
public:
    std::vector<u64> probs;  // probs[i]: frequency of symbol i
    std::vector<u64> psum;   // psum[i]: sum of probs[0..i), psum.back() == 2^enc_bits
    int ef;                  // symbol that terminates a message

    // Builds the cumulative table from `p`; the frequencies must sum to
    // 2^enc_bits (checked with assert).
    encoder(const std::vector<u64> &p, int ef) : ef(ef) {
        psum.push_back(0);
        // Iterate as u64: iterating as int would truncate frequencies >= 2^31.
        for (u64 v : p) {
            probs.push_back(v);
            psum.push_back(psum.back() + v);
        }
        assert(psum.back() == (u64{1} << enc_bits));
    }

    // Doubles `v` and drops the bit shifted out of the enc_bits-wide window.
    u64 shift(u64 v) {
        return (v << 1) & ((u64{1} << enc_bits) - 1);
    }

    // Middle-range expansion: equals 2 * (v - 2^(enc_bits-2)) for any v in the
    // middle two quarters of the window.
    u64 shift2(u64 v) {
        const u64 top = u64{1} << (enc_bits - 1);
        return ((v << 1) & (top - 1)) + (v & top);
    }

    // Appends `ef` to `msg` (the caller's vector is modified) and writes the
    // coded bits, one 0/1 int per bit, to the end of `out`. Encoding stops at
    // the first occurrence of `ef` in `msg`.
    void encode(std::vector<int> &msg, std::vector<int> &out) {
        msg.push_back(ef);
        u64 low = 0;
        u64 high = (u64{1} << enc_bits) - 1;
        int pend = 0;  // deferred bits from middle-range expansions
        const u64 div = u64{1} << (enc_bits - 1);
        // Emits `b` followed by all deferred bits, which are its complement.
        auto ppend = [&](bool b) {
            out.push_back(b);
            for (; pend > 0; --pend) out.push_back(!b);
        };
        for (int v : msg) {
            assert(v >= 0 && v < static_cast<int>(probs.size()));
            const u64 range = high - low + 1;
            high = low + scale(range, psum[v + 1]) - 1;
            low = low + scale(range, psum[v]);
            while (true) {
                if (high < div) {
                    ppend(0);
                    high = shift(high);
                    low = shift(low);
                } else if (low >= div) {
                    ppend(1);
                    high = shift(high);
                    low = shift(low);
                } else if (low >= div / 2 && high < 3 * div / 2) {
                    pend++;
                    low = shift2(low);
                    high = shift2(high);
                } else {
                    break;
                }
            }
            if (v == ef) break;
        }
        // Two closing bits (plus deferred ones) select a value that is inside
        // the final interval once the decoder pads the stream with zeros.
        pend++;
        ppend(low >= div / 2);
    }

    // Returns the symbol whose scaled cumulative slot contains offset `v`
    // inside an interval of width `r`. Requires v < r.
    int get_char(u64 r, u64 v) {
        assert(v < r);
        const int count = static_cast<int>(probs.size());
        for (int i = 0; i < count; i++) {
            if (scale(r, psum[i + 1]) > v) return i;
        }
        // Unreachable when psum.back() == 2^enc_bits, because the last slot
        // then ends at r > v. Kept as a guard for builds without asserts.
        assert(false);
        return count - 1;
    }

    // Decodes the bit vector produced by encode(). `inp` is padded in place
    // with enc_bits zero bits; decoded symbols (without the terminator) are
    // appended to `out`. At most inp.size() symbols (after padding) are
    // decoded.
    void decode(std::vector<int> &inp, std::vector<int> &out) {
        for (int i = 0; i < enc_bits; i++) inp.push_back(0);
        const int total = static_cast<int>(inp.size());
        int p = 0;
        // Bits past the end of a malformed stream read as zero instead of
        // indexing out of bounds.
        auto next_bit = [&]() -> u64 {
            return p < total ? static_cast<u64>(inp[p++]) : u64{0};
        };
        u64 value = 0;
        for (int i = 0; i < enc_bits; i++) value = (value << 1) + next_bit();
        u64 low = 0;
        u64 high = (u64{1} << enc_bits) - 1;
        const u64 div = u64{1} << (enc_bits - 1);
        for (int i = 0; i < total; i++) {
            const u64 range = high - low + 1;
            const int v = get_char(range, value - low);
            if (v != ef) out.push_back(v);
            high = low + scale(range, psum[v + 1]) - 1;
            low = low + scale(range, psum[v]);
            // Mirror the encoder's renormalisation, pulling in one new bit
            // per shift.
            while (true) {
                if (high < div || low >= div) {
                    high = shift(high);
                    low = shift(low);
                    value = shift(value);
                } else if (low >= div / 2 && high < 3 * div / 2) {
                    high = shift2(high);
                    low = shift2(low);
                    value = shift2(value);
                } else {
                    break;
                }
                value += next_bit();
            }
            if (v == ef) break;
        }
    }

private:
    // Computes (range * cum) >> enc_bits without overflowing 64 bits when
    // both operands equal 2^enc_bits (full initial interval, last symbol).
    static u64 scale(u64 range, u64 cum) {
        if (cum == (u64{1} << enc_bits)) return range;
        return (range * cum) >> enc_bits;
    }
};

// Debug helper: writes the elements of `v` to stderr, space separated.
void print(const std::vector<int> &v) {
    for (int u : v) std::cerr << u << ' ';
    std::cerr << '\n';
}

}

void Bruno(int n, int l, vector<int> A) {
    if (sz(A) == 0) {
        for (int i = 0; i < n; i++) Remove(i);
        return;
    }
    bool ex = A.back();
    A.pop_back();
    vector<int> decoded; // 0 -> XY, 1 -> ZY, 2 -> YXY
    bruno::encoder enc(bruno::optimal_probs, 3);
    enc.decode(A, decoded);
    // bruno::print(decoded);
    vector<int> tp;
    for (int v: decoded) {
        if (v == 0) {
            tp.push_back(0);
        }
        else if (v == 1) {
            tp.push_back(2);
        }
        else {
            tp.push_back(1);
            tp.push_back(0);
        }
        tp.push_back(1);
    }
    int r = sz(tp)-1 + ex;
    if (ex) tp.push_back(2);
    else tp[r] = 2;
    set<int> xz, y;
    while (sz(tp) < n) tp.push_back(1);
    for (int i = 0; i < n; i++) {
        if (tp[i]&1) y.insert(i);
        else xz.insert(i);
    }
    auto del = [&](int v) {
        if (tp[v]&1) {
            if (y.find(v) == y.end()) return;
            y.erase(v);
        }
        else {
            if (xz.find(v) == xz.end()) return;
            xz.erase(v);
        }
        Remove(v);
    };
    if (sz(xz) > 1) {
        int lim = *(xz.begin());
        for (int s = *(xz.begin()); s != *(xz.rbegin()); ) {
            int nxt = *(xz.upper_bound(s));
            if (nxt == r) break;
            if (tp[nxt] == 0) {
                s = nxt;
                lim = nxt;
                continue;
            }
            for (int i = lim+1; i <= nxt; i++) del(i);
            lim = nxt;
        }
        int s = *(xz.rbegin());
        lim = s;
        while (sz(xz) > 1) {
            int p = *(--xz.find(s));
            for (int i = p+1; i < lim; i++) del(i);
            del(p);
            lim = p;
        }
    }
    for (int i = 0; i < n; i++) del(i);
}

Compilation message

Anna.cpp: In member function 'void anna::encoder::decode(std::vector<int>&, std::vector<int>&)':
Anna.cpp:98:13: warning: unused variable 'pend' [-Wunused-variable]
   98 |         int pend = 0;
      |             ^~~~

Bruno.cpp: In member function 'void bruno::encoder::decode(std::vector<int>&, std::vector<int>&)':
Bruno.cpp:95:13: warning: unused variable 'pend' [-Wunused-variable]
   95 |         int pend = 0;
      |             ^~~~
# 결과 실행 시간 메모리 Grader output
1 Correct 0 ms 784 KB Output is correct
2 Correct 0 ms 784 KB Output is correct
3 Correct 2 ms 784 KB Output is correct
4 Correct 0 ms 796 KB Output is correct
5 Correct 1 ms 792 KB Output is correct
6 Correct 0 ms 784 KB Output is correct
7 Correct 1 ms 780 KB Output is correct
8 Correct 1 ms 800 KB Output is correct
9 Correct 0 ms 796 KB Output is correct
10 Correct 1 ms 796 KB Output is correct
11 Correct 0 ms 784 KB Output is correct
# 결과 실행 시간 메모리 Grader output
1 Correct 65 ms 10016 KB Output is correct
2 Correct 53 ms 9776 KB Output is correct
3 Correct 55 ms 9948 KB Output is correct
4 Correct 55 ms 9868 KB Output is correct
5 Correct 55 ms 9776 KB Output is correct
6 Correct 56 ms 9832 KB Output is correct
7 Correct 55 ms 9976 KB Output is correct
8 Correct 66 ms 9712 KB Output is correct
9 Correct 55 ms 9716 KB Output is correct
10 Correct 56 ms 9840 KB Output is correct
11 Correct 56 ms 9928 KB Output is correct
12 Correct 57 ms 9944 KB Output is correct
13 Correct 63 ms 9852 KB Output is correct
14 Correct 61 ms 9944 KB Output is correct
15 Correct 58 ms 9928 KB Output is correct
16 Correct 58 ms 9596 KB Output is correct
17 Correct 31 ms 6924 KB Output is correct
18 Correct 31 ms 6840 KB Output is correct
19 Correct 31 ms 6964 KB Output is correct
20 Correct 58 ms 9852 KB Output is correct
21 Correct 61 ms 9800 KB Output is correct
22 Correct 59 ms 9720 KB Output is correct
23 Correct 53 ms 9896 KB Output is correct
24 Correct 56 ms 9632 KB Output is correct
25 Correct 32 ms 6864 KB Output is correct
26 Correct 31 ms 6896 KB Output is correct
27 Correct 39 ms 6908 KB Output is correct
28 Correct 34 ms 6912 KB Output is correct
29 Correct 32 ms 6904 KB Output is correct
30 Correct 31 ms 6836 KB Output is correct
31 Correct 30 ms 6868 KB Output is correct
32 Correct 56 ms 10448 KB Output is correct
33 Correct 56 ms 9832 KB Output is correct