답안 #683421

# 제출 시각 아이디 문제 언어 결과 실행 시간 메모리
683421 2023-01-18T10:56:41 Z nifeshe 힘 센 거북 (IZhO11_turtle) C++17
75 / 100
786 ms 24300 KB
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>

// Compiler tuning requests: AVX2 code generation, -O3 level optimisation and
// loop unrolling for every function in this translation unit.
#pragma GCC target ("avx2")
#pragma GCC optimize ("O3")
#pragma GCC optimize ("unroll-loops")
// MSVC-style linker directive asking for a 16 MiB stack. GCC does not
// understand it and only reports an "ignoring #pragma comment" warning.
#pragma comment (linker, "/STACK: 16777216")

// Shorthand for std::pair members.
#define f first
#define s second
// Whole-container iterator ranges (forward and reverse).
#define all(x) (x).begin(), (x).end()
#define rall(x) (x).rbegin(), (x).rend()
// Container size as a signed integer.
#define sz(x) ((int)(x).size())
#define pb push_back
#define mp make_pair
// From here on every `int` token is really `long long`; this is why the
// entry point further down is spelled `signed main()`.
#define int long long

using namespace std;
using namespace __gnu_pbds;

// Raises `a` to `b` when `b` is strictly greater; the return value tells the
// caller whether `a` was modified.
template <typename T> inline bool umax(T &a, const T &b) {
    const bool improved = a < b;
    if(improved) a = b;
    return improved;
}
// Lowers `a` to `b` when `b` is strictly smaller; the return value tells the
// caller whether `a` was modified.
template <typename T> inline bool umin(T &a, const T &b) {
    const bool improved = a > b;
    if(improved) a = b;
    return improved;
}
// Short names for the wide arithmetic types.
typedef long long ll;
typedef long double ld;
typedef unsigned long long ull;
// Ordered set built on the pb_ds red-black tree with the order-statistics
// node update policy. It is not referenced by the code visible in this file.
template <typename T> using oset = tree<T, null_type, less <T>, rb_tree_tag, tree_order_statistics_node_update>;

// Working modulus. The initial value is only a placeholder: solve() reads the
// real modulus from the input before any arithmetic happens.
ll mod = 998244353;
// Not referenced by the code visible in this file.
const ll base = 1e6 + 5;
// Upper bound used for the random distribution below.
const ll inf = 1e18;
// Length of the fact[] / inv[] tables; C(n, k) indexes them with n directly.
const int MAX = 1e6;
// Not referenced by the code visible in this file.
const int lg = 20;

// Random number machinery: a Mersenne Twister seeded from random_device and a
// uniform distribution over [1, inf]. None of it is used by the code visible
// in this file.
random_device rd;
mt19937 gen(rd());
uniform_int_distribution<ll> dis(1, inf);

// Modular exponentiation by repeated squaring: returns x^n reduced by the
// global `mod`. For n == 0 the literal 1 is returned without reduction.
// `x` is not reduced up front, so callers must pass a value whose square
// fits in 64 bits.
int binpow(int x, int n) {
    int result = 1;
    for(; n != 0; n /= 2) {
        // The lowest remaining exponent bit decides whether the current
        // power of x takes part in the product.
        if(n & 1) {
            result = result * x % mod;
        }
        x = x * x % mod;
    }
    return result;
}

// In-place modular addition: a <- a + b followed by at most one correction
// downwards and one upwards by the global `mod`. The result is therefore only
// guaranteed to land in [0, mod) when a + b lies in [-mod, 2 * mod).
void add(int &a, int b) {
    int total = a + b;
    if(total >= mod) {
        total -= mod;
    }
    if(total < 0) {
        total += mod;
    }
    a = total;
}

// Distinct prime divisors of `mod`, filled by precalc().
vector<int> primes;
// Accumulates Euler's totient of `mod` in precalc(); starts at the empty product.
int phi = 1;
// fact[i]: product of 1..i with every prime factor of `mod` stripped, mod `mod`.
// inv[i]:  fact[i] raised to the power phi - 1 (its inverse modulo `mod`).
int fact[MAX], inv[MAX];

// Prepares everything C() needs for the current global `mod`:
//   * `primes` - the distinct prime divisors of mod, in increasing order;
//   * `phi`    - multiplied by p^(e-1) * (p - 1) for every prime power p^e of mod;
//   * `fact`   - running product of 1..i with all primes of mod divided out;
//   * `inv`    - fact[i]^(phi - 1), i.e. the inverse of fact[i] modulo mod.
// It appends to `primes` and multiplies into `phi`, so it is meant to be
// called once per modulus.
void precalc() {
    // Trial division. Each divisor found is fully divided out before moving
    // on, so only primes are ever recorded.
    int rest = mod;
    for(int d = 2; d * d <= rest; ++d) {
        if(rest % d != 0) continue;
        int term = d - 1;          // contribution of d^1
        rest /= d;
        while(rest % d == 0) {     // every further power multiplies by d
            term *= d;
            rest /= d;
        }
        phi *= term;
        primes.push_back(d);
    }
    // Whatever is left above 1 is a single prime with exponent one.
    if(rest > 1) {
        phi *= rest - 1;
        primes.push_back(rest);
    }

    fact[0] = 1;
    inv[0] = 1;
    fact[1] = 1;
    inv[1] = 1;
    for(int i = 2; i < MAX; ++i) {
        // Drop the part of i that shares factors with mod so that fact[i]
        // stays invertible.
        int stripped = i;
        for(auto p : primes) {
            while(stripped % p == 0) stripped /= p;
        }
        fact[i] = fact[i - 1] * stripped % mod;
        inv[i] = binpow(fact[i], phi - 1);
    }
}

// Memoisation for C(n, k), keyed by (n, min(k, n - k)).
// ready[key] is non-zero once the value has been computed.
map<pair<int, int>, int> ready;
// used[key] holds the cached binomial for keys marked in `ready`.
map<pair<int, int>, int> used;

// Binomial coefficient n-choose-k modulo the global (possibly composite)
// `mod`, memoised in `ready` / `used`.
//
// The part of the binomial that is coprime to mod comes from the stripped
// factorial tables; the primes of mod are then multiplied back in with the
// exponent they have in n! / (k! * (n - k)!).
// Requires precalc() to have run and 0 <= k <= n < MAX.
int C(int n, int k) {
    k = min(k, n - k);   // symmetry keeps the memo key canonical
    if(ready[{n, k}]) {
        return used[{n, k}];
    }
    ready[{n, k}] = 1;

    // Exponent of prime `prime` in value! (sum of value / prime^i).
    auto exponent_in_factorial = [&](int value, int prime) {
        int total = 0;
        for(int power = prime; power <= value; power *= prime) {
            total += value / power;
        }
        return total;
    };

    int result = fact[n] * inv[k] % mod;
    result = result * inv[n - k] % mod;
    for(auto p : primes) {
        int exponent = exponent_in_factorial(n, p)
                     - exponent_in_factorial(k, p)
                     - exponent_in_factorial(n - k, p);
        result = result * binpow(p, exponent) % mod;
    }

    used[{n, k}] = result;
    return result;
}

void solve() {
    int n, m, k, t;
    cin >> n >> m >> k >> t >> mod; precalc();
    vector<pair<int, int>> a(k);
    for(auto &[x, y] : a) {
        cin >> x >> y;
    }
    sort(all(a));
    vector<vector<int>> get(k + 2, vector<int>(k + 2));
    for(int i = 0; i < k; i++) {
        int x = a[i].f, y = a[i].s;
        get[0][i + 1] = C(x + y, y);
        x = n - a[i].f, y = m - a[i].s;
        get[i + 1][k + 1] = C(x + y, y);
    }
    get[0][k + 1] = C(n + m, n);
    for(int i = 0; i < k; i++) {
        for(int j = i + 1; j < k; j++) {
            int x = a[j].f - a[i].f, y = a[j].s - a[i].s;
            if(x < 0 || y < 0) continue;
            get[i][j] = C(x + y, y);
        }
    }
    int ans = 0;
    vector<int> ways((1 << k));
    for(int mask = 0; mask < (1 << k); mask++) {
        vector<int> need = {0};
        for(int i = 0; i < k; i++) {
            if(mask >> i & 1) {
                need.pb(i + 1);
            }
        }
        need.pb(k + 1);
        int sz = sz(need);
        int add = 1;
        for(int i = 1; i < sz; i++) {
            add = (add * get[need[i - 1]][need[i]]) % mod;
        }
        int sign = (__builtin_popcount(mask) & 1? -1 : 1);
        ways[mask] = sign * add;
    }
    for(int i = 0; i < k; i++) {
        for(int mask = 0; mask < (1 << k); mask++) {
            if(mask >> i & 1) continue;
            add(ways[mask], ways[mask | (1 << i)]);
        }
    }
    for(int mask = 0; mask < (1 << k); mask++) {
        if(__builtin_popcount(mask) > t) continue;
        int sign = (__builtin_popcount(mask) & 1? -1 : 1);
        add(ans, sign * ways[mask]);
    }
    cout << ans << '\n';
}

// Program entry point. Declared `signed` because the `int` token is
// redefined as `long long` earlier in the file.
// File redirection to turtle.in / turtle.out and reading a test-case count
// were both disabled in the original, so exactly one case is read from
// standard input.
signed main() {
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);
    for(int remaining = 1; remaining > 0; --remaining) {
        solve();
    }
    return 0;
}

Compilation message

turtle.cpp:8: warning: ignoring '#pragma comment ' [-Wunknown-pragmas]
    8 | #pragma comment (linker, "/STACK: 16777216")
      |
# 결과 실행 시간 메모리 Grader output
1 Correct 98 ms 15916 KB Output is correct
2 Incorrect 505 ms 15884 KB Output isn't correct
3 Correct 156 ms 15964 KB Output is correct
4 Correct 229 ms 16200 KB Output is correct
5 Correct 420 ms 24148 KB Output is correct
6 Incorrect 220 ms 15948 KB Output isn't correct
7 Correct 230 ms 16488 KB Output is correct
8 Correct 216 ms 15900 KB Output is correct
9 Correct 277 ms 18044 KB Output is correct
10 Correct 357 ms 20072 KB Output is correct
11 Correct 229 ms 16468 KB Output is correct
12 Correct 346 ms 20016 KB Output is correct
13 Correct 498 ms 15964 KB Output is correct
14 Correct 539 ms 16972 KB Output is correct
15 Correct 780 ms 24300 KB Output is correct
16 Incorrect 761 ms 24196 KB Output isn't correct
17 Incorrect 564 ms 17988 KB Output isn't correct
18 Incorrect 786 ms 24264 KB Output isn't correct
19 Correct 780 ms 24180 KB Output is correct
20 Correct 765 ms 24140 KB Output is correct