Submission #712812

# | Time | Username | Problem | Language | Result | Execution time | Memory
712812 | | ymm | Asceticism (JOI18_asceticism) | C++17 | 100 / 100 | 390 ms | 5584 KiB
// Submission for JOI 2018 "Asceticism" (JOI18_asceticism).
// Outline:
//   1. init()      - factorials and inverse factorials mod 1e9+7 for 0..N-1.
//   2. init_stir() - stir[i] = i! * S(m, i) for every i < N, where S is the
//                    Stirling number of the second kind, obtained from one
//                    Karatsuba polynomial product.
//   3. main()      - alternating sum of stir[] weighted by binomials.
#include <bits/stdc++.h>
// Index loops over the half-open range [l, r): upward (Loop) / downward (LoopR).
// The loop variable is a long long.
#define Loop(x,l,r) for (ll x = (l); x < (ll)(r); ++x)
#define LoopR(x,l,r) for (ll x = (r)-1; x >= (ll)(l); --x)
typedef long long ll;
typedef std::pair<int, int> pii;
typedef std::pair<ll , ll > pll;
using namespace std;

// Prime modulus; all arithmetic below is in Z/modZ.
const int mod = 1e9+7;

// x^y mod `mod` by binary exponentiation (square-and-multiply).
// Requires y >= 0; x*x must fit in long long, which holds for 0 <= x < mod.
ll pw(ll x, ll y)
{
	ll ans = 1;
	while (y) {
		if (y%2)
			ans = ans*x % mod;
		x = x*x % mod;
		y /= 2;
	}
	return ans;
}
// Modular inverse via Fermat's little theorem (mod is prime); x must be nonzero mod `mod`.
ll inv(ll x) { return pw(x, mod-2); }

// Table size: 2^17. This is also the (power-of-two) length handed to karatsuba.
const int N = 131072;
// fct[i] = i! mod p, fcti[i] = (i!)^-1 mod p, for 0 <= i < N (filled by init()).
ll fct[N], fcti[N];
void init()
{
	fct[0] = 1;
	Loop (i,1,N)
		fct[i] = fct[i-1] * i % mod;
	// One modular exponentiation for the top entry, then walk down using
	// 1/(i-1)! = i * (1/i!), avoiding N separate inversions.
	fcti[N-1] = inv(fct[N-1]);
	LoopR (i,1,N)
		fcti[i-1] = fcti[i] * i % mod;
}
// Binomial coefficient C(n, r) mod p. Returns 0 when r < 0 or r > n
// (which also covers negative n). Otherwise requires n < N.
ll C(int n, int r)
{
	if (r < 0 || n < r)
		return 0;
	return fct[n] * fcti[r] % mod * fcti[n-r] % mod;
}

// Polynomial as a dense coefficient vector (index = degree).
typedef vector<int> poly;
typedef unsigned long long ull;
// Reduce a value in [0, 2*mod) to [0, mod) with a single conditional subtraction.
// NOTE(review): the argument is expanded more than once and the final `x` is not
// parenthesized; this is safe for the argument forms used below.
#define MOD(x) ((x) >= mod? (x) - mod: x)
// In-place form of MOD (not used in this file).
#define SMOD(x) ((x) = MOD(x))

// Karatsuba product of two length-n polynomials with coefficients in [0, mod).
//   a, b : inputs, n entries each. BOTH ARE OVERWRITTEN (used as scratch).
//   ans  : output, 2n entries; ans[0..2n-2] = coefficients of a*b mod p, ans[2n-1] = 0.
//   tmp  : scratch, 2n entries; overwritten.
// Assumes n halves evenly at every level until n <= 16, e.g. a power of two
// (N = 2^17 at the top-level call) - NOTE(review): odd n > 16 would drop terms.
// The __restrict__ promises hold because every call passes pairwise-disjoint ranges.
// GCC-specific: compile this function at -O3 with loop unrolling and allow AVX.
__attribute__((optimize("O3,unroll-loops"),target("avx")))
void karatsuba(int *__restrict__ a, int *__restrict__ b, int *__restrict__ ans, int *__restrict__ tmp, int n)
{
	if (n <= 16) {
		// Schoolbook base case. Each ansl entry sums at most n <= 16 products,
		// each < mod^2 (~1.0e18), so the total (< ~1.6e19) fits in 64 unsigned bits
		// and needs only one final reduction. Shared static buffer: not reentrant.
		static ull ansl[32];
		Loop (i,0,2*n-1)
			ansl[i] = 0;
		Loop (i,0,n) Loop (j,0,n)
			ansl[i+j] += (ull)a[i] * b[j];
		Loop (i,0,2*n-1)
			ans[i] = ansl[i] % mod;
		ans[2*n-1] = 0;
		return;
	}
	// Save the original inputs: tmp[0..n) = a, tmp[n..2n) = b.
	Loop (i,0,n) {
		tmp[i] = a[i];
		tmp[i+n] = b[i];
	}
	// Pack [a_lo | b_lo] into a and compute a_lo*b_lo -> ans[0..n), using b as scratch.
	Loop (i,0,n/2)
		a[i+n/2] = b[i];
	karatsuba(a, a+n/2, ans, b, n/2);
	// Pack [a_hi | b_hi] into a and compute a_hi*b_hi -> ans[n..2n), b again as scratch.
	Loop (i,0,n/2) {
		a[i] = tmp[i+n/2];
		a[i+n/2] = tmp[i+n+n/2];
	}
	karatsuba(a, a+n/2, ans+n, b, n/2);
	// Pack [a_lo+a_hi | b_lo+b_hi] and compute their product -> b[0..n), tmp as scratch.
	// Each sum is < 2*mod (about 2.0e9), which still fits in int.
	Loop (i,0,n/2) {
		a[i] = MOD(tmp[i] + tmp[i+n/2]);
		a[i+n/2] = MOD(tmp[i+n] + tmp[i+n+n/2]);
	}
	karatsuba(a, a+n/2, b, tmp, n/2);
	// Middle term = (a_lo+a_hi)(b_lo+b_hi) - a_lo*b_lo - a_hi*b_hi. This must read
	// ans[] before the loop below modifies ans[n/2 .. n/2+n).
	Loop (i,0,n)
		b[i] = MOD(MOD(b[i] + mod - ans[i]) + mod - ans[i+n]);
	// Add the middle term at offset n/2.
	Loop (i,0,n)
		ans[i+n/2] = MOD(ans[i+n/2] + b[i]);
}

// stir[i] = i! * S(m, i) mod p for 0 <= i < N (filled by init_stir).
vector<int> stir;
// Fill `stir` for the given m. Requires init() to have run (reads fct/fcti).
// Uses the explicit formula S(m, i) = sum_{j=0..i} (j^m / j!) * ((-1)^(i-j) / (i-j)!),
// i.e. coefficient i of the product of A(x) = sum j^m/j! x^j and B(x) = sum (-1)^j/j! x^j.
void init_stir(int m)
{
	poly A(N), B(N);
	Loop (i,0,N) {
		A[i] = pw(i, m) * fcti[i] % mod;
		// (-1)^i / i!, stored as its non-negative residue.
		B[i] = i%2? mod-fcti[i]: fcti[i];
	}
	stir.resize(2*N);
	vector<int> tmp(2*N);
	karatsuba(&A[0], &B[0], &stir[0], &tmp[0], N);
	// Coefficient i < N depends only on A[0..i], B[0..i], so truncation is exact.
	stir.resize(N);
	Loop (i,0,N)
		stir[i] = fct[i] * stir[i] % mod;
}

// Reads b, then a (NOTE(review): presumably the problem's N and K - confirm against
// the statement) and prints
//   sum_{x=0}^{a-1} (-1)^x * stir[a-x] * C(b-a+x, x)  mod p.
// Hand checks with b = 3 and a = 1, 2, 3 give 1, 4, 1, consistent with Eulerian
// numbers (permutations of b elements with a-1 descents) - identity not proven here.
// Requires a < N and b < N for the table lookups.
int main()
{
	cin.tie(0) -> sync_with_stdio(false);
	ll a, b;
	cin >> b >> a;
	init();
	init_stir(b);
	// Signed accumulator: at most a terms, each in [0, mod), so |ans| < a*mod fits in ll.
	ll ans = 0;
	Loop (x,0,a) {
		ll dard = stir[a-x] * C(b-a+x, x) % mod;
		ans += x%2? -dard: dard;
	}
	// Map the possibly negative sum into [0, mod).
	ans = (ans%mod+mod) % mod;
	cout << ans << '\n';
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...