답안 #754821

# 제출 시각 아이디 문제 언어 결과 실행 시간 메모리
754821 2023-06-08T16:49:05 Z Stickfish Costinland (info1cup19_costinland) C++17
100 / 100
4 ms 340 KB
#include <iostream>
#include <vector>
#include <bitset>
#include <algorithm>
#include <set>
using ll = long long;
using namespace std;

// Propagates a single walker through an n x m grid and reports where it leaves.
//
// Bit (i * m + j) of msk marks cell (i, j) as a splitter: every arrival there is
// forwarded both to the cell below and to the cell on the right. On an unmarked
// cell arrivals keep the direction they came in with. The walker starts on
// (0, 0) travelling downwards.
//
// Returns {row, col}, where row[i] counts arrivals that leave row i through the
// right edge and col[j] counts arrivals that leave column j through the bottom.
pair<vector<int>, vector<int>> get_dp(int n, int m, int msk) {
    // fromAbove[i][j] / fromLeft[i][j]: arrivals at (i, j) moving down / right.
    // One extra row and column collect whatever steps off the grid.
    vector<vector<int>> fromAbove(n + 1, vector<int>(m + 1, 0));
    vector<vector<int>> fromLeft(n + 1, vector<int>(m + 1, 0));
    fromAbove[0][0] = 1;

    for (int i = 0; i < n; ++i) {
        for (int j = 0; j < m; ++j) {
            const int goingDown = fromAbove[i][j];
            const int goingRight = fromLeft[i][j];
            const bool splitter = (msk & (1 << (i * m + j))) != 0;
            if (splitter) {
                const int everyone = goingDown + goingRight;
                fromAbove[i + 1][j] += everyone;
                fromLeft[i][j + 1] += everyone;
            } else {
                fromAbove[i + 1][j] += goingDown;
                fromLeft[i][j + 1] += goingRight;
            }
        }
    }

    vector<int> row(n);
    vector<int> col(m);
    for (int j = 0; j < m; ++j)
        col[j] = fromAbove[n][j];
    for (int i = 0; i < n; ++i)
        row[i] = fromLeft[i][m];
    return {row, col};
}

// Exhaustive search used for small k.
//
// Walks over every 4x4 pattern that has an 'X' in the top-left corner (odd
// masks), asks get_dp for the exit counts and, for the first pattern whose
// counts add up to k, prints a 5x5 grid: the pattern itself, a trailing 'd' on
// each of its rows and a final "rrrr." row. Prints "-1" if no pattern matches.
void solve_smallk(int k) {
    constexpr int kSide = 4;
    for (int mask = 1; mask < (1 << 16); mask += 2) {
        const auto exits = get_dp(kSide, kSide, mask);
        int total = 0;
        for (int leaving : exits.first)
            total += leaving;
        for (int leaving : exits.second)
            total += leaving;
        if (total != k)
            continue;

        cout << "5 5\n";
        for (int i = 0; i < kSide; ++i) {
            for (int j = 0; j < kSide; ++j) {
                const bool marked = (mask & (1 << (i * kSide + j))) != 0;
                cout << (marked ? 'X' : '.');
            }
            cout << "d\n";
        }
        cout << "rrrr.\n";
        return;
    }
    cout << "-1\n";
}

// One non-empty grid cell: its position (row i, column j) and a type code t.
// Type codes used throughout this file: 0 = 'X', 1 = 'd', 2 = 'r'.
struct cell {
    int i = 0;
    int j = 0;
    int t = 0;

    // Members are value-initialised so a default-constructed cell is never
    // left holding indeterminate values.
    cell() = default;

    cell(int i_, int j_, int t_): i(i_), j(j_), t(t_) {}

    // Returns the transposed cell: row and column are swapped and the type
    // codes 1 ('d') and 2 ('r') trade places, while 0 ('X') is unchanged.
    cell reverse() const {
        return {j, i, (3 - t) % 3};
    }
};

// A reusable grid fragment.
//
// n, m  : amount by which solve_recurs shrinks the remaining grid when it uses
//         this block; the block's own cells have rows 0..n and columns 0..m.
// d, r  : solve_recurs applies the block to a target k only when k > r and
//         (k - r) is divisible by d, and then continues with (k - r) / d.
// cells : the non-empty cells of the fragment, relative to its top-left corner.
struct block {
    int n = 0;
    int m = 0;
    ll d = 0;
    ll r = 0;
    vector<cell> cells;

    block() = default;

    block(int n_, int m_, ll d_, ll r_, vector<cell> cells_): n(n_), m(m_), d(d_), r(r_), cells(cells_) {}

    // Builds the cell list from a row-major picture of (n + 1) x (m + 1)
    // characters: 'X', 'd' and 'r' become cells of type 0, 1 and 2; any other
    // character is treated as empty. A picture shorter than the full rectangle
    // (including the empty string) simply leaves the missing positions empty.
    block(int n_, int m_, ll d_, ll r_, const string& s): n(n_), m(m_), d(d_), r(r_) {
        for (int i = 0; i <= n; ++i) for (int j = 0; j <= m; ++j) {
            const string::size_type pos = static_cast<string::size_type>(i) * (m + 1) + j;
            // pos only grows in row-major order, so once it runs off the end of
            // s nothing further can be read. Indexing past s.size() would be
            // undefined behaviour.
            if (pos >= s.size())
                return;
            const char c = s[pos];
            if (c == 'X')
                cells.push_back({i, j, 0});
            else if (c == 'd')
                cells.push_back({i, j, 1});
            else if (c == 'r')
                cells.push_back({i, j, 2});
        }
    }
};

// Global pool of blocks: filled by init_blocks() and scanned front to back by
// solve_recurs(), which uses the first entry that fits.
vector<block> blocks;

// Greedy recursive construction of a cell list for target k inside an area
// described by n and m.
//
// Returns the list of non-empty cells on success and an empty vector on
// failure. The steps, in order:
//   * n <= 0 or m <= 0      -> failure;
//   * k == 1                -> a single 'd' cell (type 1) at (0, 0);
//   * n > m                 -> solve the transposed problem and transpose back;
//   * otherwise the first entry b of `blocks` with k > b.r, (k - b.r) divisible
//     by b.d and b.n <= n, b.m <= m is used: the rest is solved for
//     (k - b.r) / b.d in an area reduced by (b.n, b.m), shifted by that offset
//     and extended with b's own cells. There is no backtracking: if that
//     sub-problem fails, the whole call fails;
//   * if no block applies, an 'X' (type 0) is placed at (0, 0) and the rest is
//     solved for k - 1 with one column fewer, shifted one column to the right.
vector<cell> solve_recurs(ll k, int n, int m) {
    if (n <= 0 || m <= 0)
        return {};
    if (k == 1)
        return {{0, 0, 1}};
    if (n > m) {
        vector<cell> ans = solve_recurs(k, m, n);
        for (auto& x : ans)
            x = x.reverse();
        return ans;
    }

    // Iterate by reference: copying a block would deep-copy its cell vector on
    // every iteration of every recursion level.
    for (const auto& b : blocks) {
        if (k <= b.r || (k - b.r) % b.d || n < b.n || m < b.m)
            continue;
        vector<cell> ans = solve_recurs((k - b.r) / b.d, n - b.n, m - b.m);
        if (ans.empty())
            return {};
        for (auto& x : ans)
            x.i += b.n, x.j += b.m;
        for (const auto& x : b.cells)
            ans.push_back(x);
        return ans;
    }

    vector<cell> ans = solve_recurs(k - 1, n, m - 1);
    if (ans.empty())
        return {};
    for (auto& x : ans)
        ++x.j;
    ans.push_back({0, 0, 0});
    return ans;
}

void init_blocks() {
    set<pair<int, int>> st;
    for (int sz = 2; sz < 7; ++sz) for (int n = 1; n < sz; ++n) {
        int m = sz - n;
        for (int msk = 1; msk < (1 << (n * m)); msk += 2) {
            auto [row, col] = get_dp(n, m, msk);
            for (int i = 1; i < n; ++i)
                row[i] += row[i - 1];
            for (int j = 1; j < m; ++j)
                col[j] += col[j - 1];
            int bld = row[n - 1] + col[m - 1];
            for (int rmsk = 0; rmsk < (1 << n); ++rmsk) for (int cmsk = 0; cmsk < (1 << m); ++cmsk) {
                int blr = 0;
                for (int i = 0; i < n; ++i) if (rmsk & (1 << i))
                    blr += row[i];
                for (int i = 0; i < m; ++i) if (cmsk & (1 << i))
                    blr += col[i];
                if (st.find({bld, blr % bld}) != st.end()) {
                    continue;
                }
                st.insert({bld, blr % bld});
                block bl = {n, m, bld, blr, ""};
                for (int i = 0; i < n; ++i) {
                    for (int j = 0; j < m; ++j) {
                        if (msk & (1 << (i * m + j)))
                            bl.cells.push_back({i, j, 0});
                    }
                    if (rmsk & (1 << i))
                        bl.cells.push_back({i, m, 0});
                    else
                        bl.cells.push_back({i, m, 1});
                }
                for (int j = 0; j < m; ++j) {
                    if (cmsk & (1 << j))
                        bl.cells.push_back({n, j, 0});
                    else
                        bl.cells.push_back({n, j, 2});
                }
                blocks.push_back(bl);
            }
        }
    }
    reverse(blocks.begin(), blocks.end());
    //blocks.push_back({
        //1, 2, 3, 2,
        //"XXXXr."
    //});
    //blocks.push_back({
        //1, 2, 3, 1,
        //"XXdXr."
    //});
    //blocks.push_back({
        //1, 2, 3, 0,
        //"XXdrr."
    //});
    //blocks.push_back({
        //1, 1, 2, 0,
        //"Xdr."
    //});
}

signed main() {
    ll k;
    cin >> k;
    if (k <= 19) {
        solve_smallk(k);
        return 0;
    }
    init_blocks();
    vector<char> symbols = {'.', 'X', 'd', 'r'};
    for (int n = 1; n <= 200; ++n) {
        vector<cell> ans = solve_recurs(k, n - 1, n - 1);
        if (ans.empty())
            continue;
        vector<vector<int>> v(n, vector<int>(n, -1));
        for (int i = 0; i < n - 1; ++i)
            v[i][n - 1] = 1, v[n - 1][i] = 2;
        for (auto x : ans) if (x.t != -6)
            v[x.i][x.j] = x.t;
        
        cout << n << ' ' << n << '\n';
        for (int i = 0; i < n; ++i) {
            for (int j = 0; j < n; ++j) {
                cout << symbols[v[i][j] + 1];
            }
            cout << '\n';
        }
        return 0;
    }
}
# 결과 실행 시간 메모리 Grader output
1 Correct 1 ms 212 KB Correct! Your size: 5
2 Correct 0 ms 212 KB Correct! Your size: 5
3 Correct 0 ms 212 KB Correct! Your size: 5
4 Correct 0 ms 212 KB Correct! Your size: 5
5 Correct 0 ms 212 KB Correct! Your size: 5
6 Correct 1 ms 212 KB Correct! Your size: 5
7 Correct 1 ms 212 KB Correct! Your size: 5
8 Correct 1 ms 212 KB Correct! Your size: 5
9 Correct 1 ms 212 KB Correct! Your size: 5
# 결과 실행 시간 메모리 Grader output
1 Correct 3 ms 340 KB Correct! Your size: 43
2 Correct 4 ms 308 KB Correct! Your size: 44
3 Correct 3 ms 340 KB Correct! Your size: 44
4 Correct 3 ms 340 KB Correct! Your size: 44
5 Correct 3 ms 340 KB Correct! Your size: 44
6 Correct 3 ms 340 KB Correct! Your size: 45
7 Correct 3 ms 340 KB Correct! Your size: 44
8 Correct 3 ms 340 KB Correct! Your size: 45