#include<stdio.h>
#include<math.h>
#include<algorithm>
#include<vector>
#include<map>
using namespace std;
typedef long long ll;
// Pair of 64-bit integers: used as the (modulus, residue) key of DB and as the
// (x, y) coefficient pair returned by extended_gcd.
typedef pair<ll,ll> pii;
// Scratch arrays for the Chinese-remainder step: fact() stores in a[i] the
// residue and in mod[i] the prime-power modulus for each entry of L.
ll mod[100];
ll a[100];
// N and M are set by Cat() (Catalan index and modulus); fact() reads N.
// K is not used in the visible code.
ll N, K, M;
// Prime-power factorization of the input modulus, built in main():
// L[c] = p^k, X[c] = p, CT[c] = k.
vector<ll> L, X, CT;
// FAC[p^k] = product of all j in [1, p^k] not divisible by p, mod p^k.
map<ll, ll> FAC;
// DB[(p^k, r)] = product of all j in [1, r] not divisible by p, mod p^k,
// stored for r in [0, p^k).
map<pii, ll> DB;
// Computes a^b mod `mod` by binary (square-and-multiply) exponentiation.
// Assumes b >= 0 and mod >= 1; a negative b is treated like b == 0.
// The base is reduced into [0, mod) first. This keeps base * base from
// overflowing for any a as long as mod <= ~3e9. The result is always in
// [0, mod), including mod == 1, where 0 is returned.
ll pw(ll a, ll b, ll mod)
{
    ll base = a % mod;
    if (base < 0) base += mod;   // C++ '%' keeps the sign of a
    ll result = 1 % mod;         // a^0, already reduced (0 when mod == 1)
    for (; b > 0; b >>= 1) {
        if (b & 1) result = result * base % mod;
        base = base * base % mod;
    }
    return result;
}
// Returns (N!)_p mod mul: the product of 1..N with every factor of p removed,
// reduced modulo mul (the p^k entry whose tables live in DB/FAC). It uses
//   (N!)_p = FAC[mul]^(N / mul) * DB[(mul, N % mul)] * ((N / p)!)_p   (mod mul),
// evaluated iteratively over N, N/p, N/p^2, ... until N reaches 0.
// N == 0 yields 1. DB and FAC are read through operator[], so a missing key
// would be inserted with value 0.
ll get_fact(ll N, ll mul, ll p)
{
    ll acc = 1;
    while (N != 0) {
        const ll blocks = N / mul;   // number of complete periods of length mul
        const ll tail = N % mul;     // leftover partial period
        ll term = DB[pii(mul, tail)];
        const ll period = FAC[mul];
        if (blocks != 0) {
            term = pw(period, blocks, mul) * term % mul;
        }
        acc = acc * term % mul;
        N /= p;                      // descend to the multiples of p
    }
    return acc;
}
// Greatest common divisor by the Euclidean algorithm, written iteratively.
// gcd(a, 0) == a. For negative inputs the sign follows C++ '%' semantics.
ll gcd(ll a, ll b){
    while (b != 0) {
        const ll rem = a % b;
        a = b;
        b = rem;
    }
    return a;
}
// Extended Euclid: returns (x, y) such that a*x + b*y == gcd(a, b).
// It recurses on (b, a % b) and back-substitutes that pair's coefficients.
pii extended_gcd(ll a, ll b){
    if (b != 0) {
        const pii sub = extended_gcd(b, a % b);
        const ll x = sub.second;
        const ll y = sub.first - sub.second * (a / b);
        return pii(x, y);
    }
    // gcd(a, 0) = a = a*1 + 0*0
    return pii(1, 0);
}
// Returns the x coefficient of extended_gcd(a, m), normalized into [0, m) for
// m > 0. This is the modular inverse of a only when gcd(a, m) == 1; no check
// is made. Returns 0 for m == 1.
ll modinverse(ll a, ll m)
{
    const ll x = extended_gcd(a, m).first;
    const ll r = x % m;          // in (-m, m)
    return (r + m) % m;
}
// Chinese remainder theorem: solves x ≡ a[i] (mod n[i]) for i in [0, size) by
// folding the congruences left to right.
// - The moduli need not be pairwise coprime. Each step divides out
//   g = gcd(current modulus, n[i]).
// - Returns 0 for an empty system (size <= 0).
// - Returns -1 if two congruences contradict each other.
// - Otherwise returns x; when every a[i] is already in [0, n[i]),
//   x is in [0, lcm(n[0..size-1])).
// - For coprime moduli the result equals the previous pairwise-recursive
//   implementation.
// Assumes every n[i] >= 1 and that the running lcm fits in a long long.
// The input arrays are not modified.
ll cr(ll *a, ll *n, int size)
{
    if (size <= 0) return 0;              // empty system: 0 is a valid solution
    ll res = a[0];                        // current solution ...
    ll modulus = n[0];                    // ... modulo this combined modulus
    for (int i = 1; i < size; i++) {
        const ll g = gcd(modulus, n[i]);
        const ll diff = a[i] - res;
        if (diff % g != 0) return -1;     // inconsistent congruences
        const ll step = n[i] / g;
        // Solve (modulus / g) * k ≡ diff / g (mod step) for k in [0, step).
        ll rhs = (diff / g) % step;
        if (rhs < 0) rhs += step;
        const ll k = rhs * modinverse(modulus / g % step, step) % step;
        res += modulus * k;               // still ≡ a[j] for all j <= i
        modulus *= step;                  // lcm(modulus, n[i])
    }
    return res;
}
// For every prime power p^k in L (with p = X[c], k = CT[c]), computes a residue
// modulo p^k, then CRT-combines the residues via cr().
//  - ch == true : residue = inverse of (n!)_p, the p-free part of n!.
//  - ch == false: residue = (n!)_p * p^v mod p^k, with
//      v = v_p((2N)!) - v_p((N+1)!) - v_p(N!)   (Legendre's formula).
//    The residue is 0 when v > k.
// Note: v uses the global N (set by Cat), not the parameter n. The M
// parameter is unused. Overwrites the global scratch arrays a[] and mod[].
ll fact(ll n, ll M, bool ch)
{
    int count = 0;
    for (size_t idx = 0; idx < L.size(); ++idx) {
        const ll pk = L[idx];
        const ll prime = X[idx];
        const ll exponent = CT[idx];
        ll residue = get_fact(n, pk, prime);
        if (!ch) {
            ll val = 0;
            for (ll s = N * 2; s != 0; s /= prime) val += s / prime;
            for (ll s = N + 1; s != 0; s /= prime) val -= s / prime;
            for (ll s = N; s != 0; s /= prime) val -= s / prime;
            if (val > exponent) {
                residue = 0;                       // p^val ≡ 0 (mod p^k)
            } else {
                for (ll j = 0; j < val; ++j) residue = residue * prime % pk;
            }
        }
        a[count] = residue;
        mod[count] = pk;
        ++count;
    }
    if (ch) {
        for (int j = 0; j < count; ++j) a[j] = modinverse(a[j], mod[j]);
    }
    return cr(a, mod, count);
}
// Returns the Catalan number C_n = (2n)! / (n! (n+1)!) modulo m, combining
// fact(2N) with the inverses of the p-free parts of N! and (N+1)!.
// Side effect: stores n and m in the globals N and M.
// NOTE(review): fact() uses the factorization stored in L/X/CT, so m is
// expected to be the modulus main() factored. Confirm before reusing it elsewhere.
ll Cat(ll n, ll m)
{
    N = n;
    M = m;
    const ll numerator = fact(2 * N, M, false);
    const ll invNFact = fact(N, M, true);
    const ll invN1Fact = fact(N + 1, M, true);
    return numerator * invNFact % M * invN1Fact % M;
}
// Table for the p == 0, q == 1 case: main() fills DP[1..1000000] mod M and
// answers each query A with DP[A].
ll DP[1000005];
int main()
{
int N;
int p, q, M, Q;
scanf("%d%d%d%d", &p, &q, &M, &Q);
if(p == 1 &&q == 0){
ll tmp = M;
for(ll i = 2; i <= tmp; i++){
if( i*i > tmp ) i = tmp;
ll m = 1, cnt = 0;
while( tmp % i == 0 ){
m *= i; cnt++;
tmp /= i;
}
if( m == 1 ) continue;
ll fac = 1;
for(ll j = 1; j <= m; j++){
DB[pii(m, j-1)] = fac;
if( j % i == 0 ) continue;
fac = fac * j % m;
}
FAC[m] = fac;
L.push_back(m);
X.push_back(i);
CT.push_back(cnt);
}
for(int i = 1; i <= Q; i++){
for(int j = 0; j < 100; j++) a[j] = mod[j] = 0;
int A;
scanf("%d", &A);
if( M == 1 ){
printf("0\n");
continue;
}
ll mul = 2, ans = 1, su = A;
while(su){
if(su%2) ans = mul * ans % M;
su /= 2; mul = mul * mul % M;
}
if (A&1) printf("%lld\n", ans);
else printf("%lld\n", (ans - Cat(A/2, M) + M)%M);
}
}
else if(p == 0 &&q == 1 ){
DP[1] = 2;
ll tot = 0, mul = 2;
for(int i = 2; i <= 1000000; i++){
mul = mul * 2 % M;
if( i%2 ) tot = tot * 2 % M;
else tot = (tot * 2 + DP[i/2]) % M;
DP[i] = (mul - tot + M*2) % M;
}
for(int t = 1; t <= Q; t++){
int A;
scanf("%d", &A);
printf("%lld\n", DP[A]);
}
}
return 0;
}
/* Grader results:
 * #  | Result (결과) | Time (실행 시간) | Memory (메모리) | Grader output
 * 1  | Correct       | 1450 ms          | 9032 KB         | Output is correct
 * 2  | Correct       | 2151 ms          | 9032 KB         | Output is correct
 * 3  | Correct       | 1157 ms          | 9032 KB         | Output is correct
 * 4  | Correct       | 1084 ms          | 9032 KB         | Output is correct
 * 5  | Correct       | 1240 ms          | 9956 KB         | Output is correct
 * 6  | Correct       | 1043 ms          | 71468 KB        | Output is correct
 */
/* Grader results:
 * #  | Result (결과) | Time (실행 시간) | Memory (메모리) | Grader output
 * 1  | Correct       | 83 ms            | 9032 KB         | Output is correct
 * 2  | Correct       | 73 ms            | 9032 KB         | Output is correct
 * 3  | Correct       | 84 ms            | 9032 KB         | Output is correct
 * 4  | Correct       | 88 ms            | 9032 KB         | Output is correct
 * 5  | Correct       | 93 ms            | 9032 KB         | Output is correct
 * 6  | Correct       | 90 ms            | 9032 KB         | Output is correct
 */
/* Grader results:
 * #  | Result (결과) | Time (실행 시간) | Memory (메모리) | Grader output
 * 1  | Incorrect     | 0 ms             | 9028 KB         | Output isn't correct
 * 2  | Halted        | 0 ms             | 0 KB            | -
 */