#include <bits/stdc++.h>
// Ascending half-open loop: x takes the values l, l+1, ..., r-1 as a long long.
#define Loop(x,l,r) for (ll x = (l); x < (ll)(r); ++x)
// Descending half-open loop: x takes the values r-1, r-2, ..., l as a long long.
#define LoopR(x,l,r) for (ll x = (r)-1; x >= (ll)(l); --x)
typedef long long ll;
typedef std::pair<int, int> pii;
typedef std::pair<ll , ll > pll;
using namespace std;
// Prime modulus for all modular arithmetic below (inv() relies on it being prime).
const int mod = 1e9+7;
// Modular exponentiation by repeated squaring: returns x^y mod `mod`.
// The base is first reduced into [0, mod). That keeps x*x inside long long for any
// input base and makes negative bases yield a non-negative result. For bases already
// in [0, mod) the result is unchanged.
// y is expected to be >= 0; pw(x, 0) == 1 (including pw(0, 0)).
ll pw(ll x, ll y)
{
	x %= mod;
	if (x < 0)
		x += mod;
	ll ans = 1;
	while (y) {
		if (y % 2)
			ans = ans * x % mod;
		x = x * x % mod;
		y /= 2;
	}
	return ans;
}
// Modular inverse by Fermat's little theorem: since mod is prime, x^(mod-2) == x^{-1}.
ll inv(ll x)
{
	const ll fermatExponent = mod - 2;
	return pw(x, fermatExponent);
}
// Table size (2^17): factorials are precomputed for 0..N-1 and init_stir() convolves
// arrays of this length, which karatsuba() requires to be a power of two.
const int N = 131072;
// fct[i] = i! mod p and fcti[i] = (i!)^{-1} mod p for 0 <= i < N; filled by init().
ll fct[N], fcti[N];
void init()
{
fct[0] = 1;
Loop (i,1,N)
fct[i] = fct[i-1] * i % mod;
fcti[N-1] = inv(fct[N-1]);
LoopR (i,1,N)
fcti[i-1] = fcti[i] * i % mod;
}
// Binomial coefficient C(n, r) mod p using the precomputed factorial tables.
// Returns 0 when r < 0 or r > n (this also covers negative n with r >= 0).
// Requires init() to have run and n < N whenever 0 <= r <= n.
ll C(int n, int r)
{
	const bool outsideTriangle = r < 0 || r > n;
	if (outsideTriangle)
		return 0;
	const ll invDenominator = fcti[r] * fcti[n - r] % mod;
	return fct[n] * invDenominator % mod;
}
// Polynomial stored as its coefficient list, lowest degree first.
typedef vector<int> poly;
typedef unsigned long long ull;
// Reduces a value in [0, 2*mod) into [0, mod). The argument is evaluated more than
// once, so it must be side-effect free.
// NOTE(review): the final `x` is not parenthesised; pass only simple expressions.
#define MOD(x) ((x) >= mod? (x) - mod: x)
// In-place variant of MOD; not used in the code visible here.
#define SMOD(x) ((x) = MOD(x))
// Multiplies two n-coefficient polynomials modulo `mod` with Karatsuba's method.
//
// a, b : input coefficients, length n, every value in [0, mod).
//        For n > 16 both arrays are used as scratch and are overwritten.
// ans  : output, length 2n; ans[k] = sum_{i+j=k} a[i]*b[j] mod p, and ans[2n-1] = 0.
// tmp  : scratch buffer of at least 2n ints.
// n    : NOTE(review): n is halved until n <= 16 without handling odd sizes, so it
//        should be a power of two (or a power of two times a value <= 16).
// The base case uses a function-level static buffer, so this is not thread-safe.
void karatsuba(int *__restrict__ a, int *__restrict__ b, int *__restrict__ ans, int *__restrict__ tmp, int n)
{
if (n <= 16) {
// Schoolbook product. Each coefficient gets at most 16 products, each < mod^2
// (~1e18), so the unsigned 64-bit sum cannot overflow before the final reduction.
static ull ansl[32];
Loop (i,0,2*n-1)
ansl[i] = 0;
Loop (i,0,n) Loop (j,0,n)
ansl[i+j] += (ull)a[i] * b[j];
Loop (i,0,2*n-1)
ans[i] = ansl[i] % mod;
ans[2*n-1] = 0;
return;
}
// Save the inputs: tmp[0..n) = a, tmp[n..2n) = b.
Loop (i,0,n) {
tmp[i] = a[i];
tmp[i+n] = b[i];
}
// Pack a = [a_lo | b_lo]; ans[0..n) = a_lo * b_lo, using b as the child's scratch.
Loop (i,0,n/2)
a[i+n/2] = b[i];
karatsuba(a, a+n/2, ans, b, n/2);
// Pack a = [a_hi | b_hi]; ans[n..2n) = a_hi * b_hi.
Loop (i,0,n/2) {
a[i] = tmp[i+n/2];
a[i+n/2] = tmp[i+n+n/2];
}
karatsuba(a, a+n/2, ans+n, b, n/2);
// Pack a = [a_lo + a_hi | b_lo + b_hi]. Sums are < 2*mod, which still fits in int.
Loop (i,0,n/2) {
a[i] = MOD(tmp[i] + tmp[i+n/2]);
a[i+n/2] = MOD(tmp[i+n] + tmp[i+n+n/2]);
}
// b[0..n) = (a_lo + a_hi)(b_lo + b_hi). The saved copies in tmp are no longer
// needed, so tmp serves as the child's scratch.
karatsuba(a, a+n/2, b, tmp, n/2);
// Cross term: b = (a_lo+a_hi)(b_lo+b_hi) - a_lo*b_lo - a_hi*b_hi = a_lo*b_hi + a_hi*b_lo.
Loop (i,0,n)
b[i] = MOD(MOD(b[i] + mod - ans[i]) + mod - ans[i+n]);
// Add the cross term at offset n/2.
Loop (i,0,n)
ans[i+n/2] = MOD(ans[i+n/2] + b[i]);
}
vector<int> stir;
void init_stir(int m)
{
poly A(N), B(N);
Loop (i,0,N) {
A[i] = pw(i, m) * fcti[i] % mod;
B[i] = i%2? mod-fcti[i]: fcti[i];
}
stir.resize(2*N);
vector<int> tmp(2*N);
karatsuba(&A[0], &B[0], &stir[0], &tmp[0], N);
stir.resize(N);
Loop (i,0,N)
stir[i] = fct[i] * stir[i] % mod;
}
int main()
{
cin.tie(0) -> sync_with_stdio(false);
ll a, b;
cin >> b >> a;
init();
init_stir(b);
ll ans = 0;
Loop (x,0,a) {
ll dard = stir[a-x] * C(b-a+x, x) % mod;
ans += x%2? -dard: dard;
}
ans = (ans%mod+mod) % mod;
cout << ans << '\n';
}
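/*
| # | Result              | Execution time | Memory  | Grader output       |
|---|---------------------|----------------|---------|---------------------|
| 1 | Execution timed out | 700 ms         | 5580 KB | Time limit exceeded |
| 2 | Halted              | 0 ms           | 0 KB    | -                   |
*/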
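/*
| # | Result              | Execution time | Memory  | Grader output       |
|---|---------------------|----------------|---------|---------------------|
| 1 | Execution timed out | 700 ms         | 5580 KB | Time limit exceeded |
| 2 | Halted              | 0 ms           | 0 KB    | -                   |
*/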
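/*
| # | Result              | Execution time | Memory  | Grader output       |
|---|---------------------|----------------|---------|---------------------|
| 1 | Execution timed out | 700 ms         | 5580 KB | Time limit exceeded |
| 2 | Halted              | 0 ms           | 0 KB    | -                   |
*/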
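/*
| # | Result              | Execution time | Memory  | Grader output       |
|---|---------------------|----------------|---------|---------------------|
| 1 | Execution timed out | 700 ms         | 5580 KB | Time limit exceeded |
| 2 | Halted              | 0 ms           | 0 KB    | -                   |
*/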