제출 #1176036

# | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리
1176036 | TrendBattles | Snake Escaping (JOI18_snake_escaping) | C++17 | 22 / 100 | 175 ms | 23796 KiB
#include <bits/stdc++.h>
using namespace std;
using lli = long long int;

// Optional input/output files for local runs: main() redirects stdin/stdout to them only
// when INFILE can be opened. Typed constexpr constants instead of macros keep them scoped.
constexpr const char* INFILE = "snake.inp";
constexpr const char* OUTFILE = "snake.out";


namespace SUBTASK_small_L {
    // Direct tabulation for small L.
    // Every 0/1/? query string is read as a base-3 number: '0' -> 0, '1' -> 1, '?' -> 2,
    // with the first character most significant. table[code] holds that query's answer.
    // Building the table visits every (index, wildcard-set) pair: 4^L work, 3^L ints of memory.
    void main(int L, int Q, string digits) {
        const int full = 1 << L;

        vector<int> power(L + 1, 1);  // power[k] = 3^k
        for (int k = 1; k <= L; ++k) {
            power[k] = 3 * power[k - 1];
        }

        // ternary[m]: the binary digits of m re-read in base 3 (bit k contributes 3^k).
        vector<int> ternary(full, 0);
        for (int m = 1; m < full; ++m) {
            const int lowest = __builtin_ctz(m);
            ternary[m] = ternary[m & (m - 1)] + power[lowest];
        }

        // Index idx contributes to the pattern that has '?' exactly at the bits of `wild`
        // and agrees with idx everywhere else.
        vector<int> table(power[L], 0);
        for (int idx = 0; idx < full; ++idx) {
            const int value = digits[idx] - '0';
            for (int wild = 0; wild < full; ++wild) {
                table[2 * ternary[wild] + ternary[idx & ~wild]] += value;
            }
        }

        for (int q = 0; q < Q; ++q) {
            string pattern;
            cin >> pattern;
            int code = 0;
            for (int i = 0; i < L; ++i) {
                const char c = pattern[i];
                code = 3 * code + (c == '1' ? 1 : (c == '?' ? 2 : 0));
            }
            cout << table[code] << '\n';
        }
    }
}

// namespace SUBTASK_small_Q {
//     const int HIGH_BIT = 20, MAX_MASK = 1 << HIGH_BIT;
//     int sum_A[MAX_MASK], sum_B[MAX_MASK];
//     int bit_cnt[MAX_MASK];
//     void main(int L, int Q, string digits) {
//         for (int m = 0; m < (1 << L); ++m) {
//             sum_A[m] = digits[m] - '0';
//             sum_B[~(-1 << L) ^ m] = sum_A[m];

//             bit_cnt[m] = bit_cnt[m >> 1] ^ (m & 1);
//         }
        
//         for (int b = 0; b < L; ++b) {
//             for (int m = 0; m < (1 << L); ++m) {
//                 if ((m >> b & 1) == 0) {
//                     sum_A[m] += sum_A[m ^ (1 << b)];
//                     sum_B[m] += sum_B[m ^ (1 << b)];
//                 }
//             }
//         }

//         for (int _ = 0; _ < Q; ++_) {
//             string que; cin >> que;
//             int ans = 0;
//             if (count(que.begin(), que.end(), '0') <= L / 2) {
//                 int allowed = 0, banned = 0;
//                 for (int i = 0; i < L; ++i) {
//                     if (que[i] == '1') allowed ^= 1 << (L - i - 1);
//                     if (que[i] == '0') banned ^= 1 << (L - i - 1);
//                 }

//                 for (int p = banned; p >= 0; p = (p - 1) & banned) {
//                     if (bit_cnt[p]) {
//                         ans -= sum_A[allowed ^ p];
//                     } else {
//                         ans += sum_A[allowed ^ p];
//                     }

//                     if (p == 0) break;
//                 }
//             } else {
//                 int allowed = 0, banned = 0;
//                 for (int i = 0; i < L; ++i) {
//                     if (que[i] == '0') allowed ^= 1 << (L - i - 1);
//                     if (que[i] == '1') banned ^= 1 << (L - i - 1);
//                 }

//                 for (int p = banned; p >= 0; p = (p - 1) & banned) {
//                     if (bit_cnt[p]) {
//                         ans -= sum_B[allowed ^ p];
//                     } else {
//                         ans += sum_B[allowed ^ p];
//                     }

//                     if (p == 0) break;
//                 }
//             }
//             cout << ans << '\n';
//         }
//     }
// }

namespace SUBTASK_main {
    // Split of the L query characters:
    // - The last SMALL_BIT characters (the lowest index bits) are pre-expanded into a base-3
    //   table over {0, 1, ?}.
    // - The remaining HIGH_BIT = L - SMALL_BIT "high" characters are resolved per query by
    //   subset enumeration.
    // Requires SMALL_BIT <= L <= MAX_BIT.
    const int MAX_BIT = 20, SMALL_BIT = 4, SMALL_SIZE = 81;  // SMALL_SIZE == 3^SMALL_BIT
    const int SMALL_MAGIC = (1 << SMALL_BIT) - 1;

    // sum[hi][t]: total of digits[(hi << SMALL_BIT) | lo] over every low part lo that matches
    // the base-3 low pattern t. After the superset-sum pass in main(), the row hi instead
    // accumulates over every high part that is a superset of hi.
    int sum[1 << (MAX_BIT - SMALL_BIT)][SMALL_SIZE];
    int pow_3[SMALL_BIT + 1], mask_3[1 << SMALL_BIT];

    const int BIG_SIZE = 1 << (MAX_BIT - SMALL_BIT);
    int8_t bit_cnt[BIG_SIZE];  // parity of popcount(m): the inclusion-exclusion sign

    void main(int L, int Q, string digits) {
        const int HIGH_BIT = L - SMALL_BIT;
        const int HIGH_SIZE = 1 << HIGH_BIT;

        for (int m = 0; m < BIG_SIZE; ++m) {
            bit_cnt[m] = bit_cnt[m >> 1] ^ (m & 1);
        }

        pow_3[0] = 1;
        for (int i = 1; i <= SMALL_BIT; ++i) {
            pow_3[i] = pow_3[i - 1] * 3;
        }
        // mask_3[m]: binary digits of m re-read in base 3.
        for (int m = 1; m < (1 << SMALL_BIT); ++m) {
            int p = m & -m;
            mask_3[m] = mask_3[m ^ p] + pow_3[__builtin_ctz(p)];
        }

        // Start from a clean table so repeated calls do not accumulate stale sums.
        memset(sum, 0, sizeof(sum[0]) * HIGH_SIZE);
        for (int pos = 0; pos < (1 << L); ++pos) {
            int a = pos >> SMALL_BIT, b = pos & SMALL_MAGIC;
            int x = digits[pos] - '0';
            // `apply` = set of low positions shown as '?'; the others must equal pos's bits.
            for (int apply = 0; apply <= SMALL_MAGIC; ++apply) {
                sum[a][(mask_3[apply] << 1) + mask_3[b & ~apply]] += x;
            }
        }

        // Per-query state for queries answered after the superset-sum pass.
        // The high masks use up to MAX_BIT - SMALL_BIT = 16 bits, so they must be stored
        // unsigned: a signed 16-bit type would turn bit 15 negative. An explicit flag marks
        // deferred queries, so no mask value doubles as a sentinel.
        vector<int> finale(Q, 0);
        vector<uint16_t> one_mask(Q, 0), zero_mask(Q, 0);
        vector<uint8_t> state(Q, 0);  // base-3 low pattern, 0..SMALL_SIZE-1
        vector<char> deferred(Q, 0);
        bool any_deferred = false;

        for (int query = 0; query < Q; ++query) {
            string que; cin >> que;

            // High characters: que[i] describes high bit HIGH_BIT - 1 - i.
            int ones = 0, zeros = 0, questions = 0;
            for (int i = 0; i < HIGH_BIT; ++i) {
                const int bit = 1 << (HIGH_BIT - 1 - i);
                if (que[i] == '1') ones |= bit;
                else if (que[i] == '0') zeros |= bit;
                else if (que[i] == '?') questions |= bit;
            }

            // Low characters: que[i] is base-3 digit L - 1 - i of the table column.
            int s = 0;
            for (int i = HIGH_BIT; i < L; ++i) {
                if (que[i] == '1') s += pow_3[L - i - 1];
                else if (que[i] == '?') s += pow_3[L - i - 1] << 1;
            }

            if (__builtin_popcount(questions) <= HIGH_BIT / 2) {
                // Few '?': enumerate every fill-in of them on the raw (pre-transform) table.
                int total = 0;
                for (int p = questions; ; p = (p - 1) & questions) {
                    total += sum[ones | p][s];
                    if (p == 0) break;
                }
                finale[query] = total;
            } else {
                // Many '?' means few '0': answer later by inclusion-exclusion over the zeros.
                deferred[query] = 1;
                any_deferred = true;
                one_mask[query] = static_cast<uint16_t>(ones);
                zero_mask[query] = static_cast<uint16_t>(zeros);
                state[query] = static_cast<uint8_t>(s);
            }
        }

        // Sum over supersets of the high mask (only needed if some query was deferred).
        if (any_deferred) {
            for (int b = 0; b < HIGH_BIT; ++b) {
                for (int m = 0; m < HIGH_SIZE; ++m) {
                    if (m >> b & 1) continue;

                    int* dst = sum[m];
                    const int* src = sum[m | (1 << b)];
                    for (int s = 0; s < SMALL_SIZE; ++s) {
                        dst[s] += src[s];
                    }
                }
            }
        }

        for (int query = 0; query < Q; ++query) {
            if (deferred[query]) {
                const int ones = one_mask[query];
                const int zeros = zero_mask[query];
                const int s = state[query];
                // Indices whose high part is a superset of `ones` and avoids every '0' bit:
                // sum over T subset of zeros of (-1)^|T| * supersetSum(ones | T).
                int total = 0;
                for (int p = zeros; ; p = (p - 1) & zeros) {
                    const int part = sum[ones | p][s];
                    total += bit_cnt[p] ? -part : part;
                    if (p == 0) break;
                }
                finale[query] = total;
            }

            cout << finale[query] << '\n';
        }
    }
}

int main() {
    ios::sync_with_stdio(0);
    cin.tie(0);

    // Local-judge convenience: when INFILE exists, read input from it and write answers to
    // OUTFILE. The probe handle is closed right away so it does not leak. A failed
    // redirection is reported instead of continuing on a closed stream.
    if (FILE* probe = fopen(INFILE, "r")) {
        fclose(probe);
        if (!freopen(INFILE, "r", stdin) || !freopen(OUTFILE, "w", stdout)) {
            perror("freopen");
            return 1;
        }
    }

    // Header: L, Q, then the digit string indexed by 0 .. 2^L - 1.
    // The Q queries are read by the solver.
    int L, Q;
    cin >> L >> Q;
    string digits;
    cin >> digits;

    if (L <= 10) {
        // Small L: the full 3^L answer table fits easily.
        SUBTASK_small_L::main(L, Q, digits);
    } else {
        // Disabled alternative kept for reference:
        // if (Q <= 50'000) { SUBTASK_small_Q::main(L, Q, digits); return 0; }
        SUBTASK_main::main(L, Q, digits);
    }
    return 0;
}

컴파일 시 표준 에러 (stderr) 메시지

snake_escaping.cpp: In function 'int main()':
snake_escaping.cpp:219:16: warning: ignoring return value of 'FILE* freopen(const char*, const char*, FILE*)' declared with attribute 'warn_unused_result' [-Wunused-result]
  219 |         freopen(INFILE, "r", stdin);
      |         ~~~~~~~^~~~~~~~~~~~~~~~~~~~
snake_escaping.cpp:220:16: warning: ignoring return value of 'FILE* freopen(const char*, const char*, FILE*)' declared with attribute 'warn_unused_result' [-Wunused-result]
  220 |         freopen(OUTFILE, "w", stdout);
      |         ~~~~~~~^~~~~~~~~~~~~~~~~~~~~~
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...