Submission #777537

# | Time | Username | Problem | Language | Result | Execution time | Memory
777537 | qwerasdfzxcl | Festivals in JOI Kingdom 2 (JOI23_festival2) | C++17
100 / 100
4522 ms | 8320 KiB
// Festivals in JOI Kingdom 2 (JOI23_festival2).
//
// Reads n and a modulus MOD from stdin and prints
//     (1 * 3 * 5 * ... * (2n-1) - dp[n+1][1]) mod MOD,
// where dp is filled by a divide-and-conquer over indices whose cross-half
// contributions are computed with Karatsuba polynomial multiplication.
#include <algorithm>
#include <cassert>
#include <cstdio>
#include <utility>
#include <vector>
using namespace std;
using ll = long long;

// Size of every global table. init() touches fact[n*3+100] and solve()
// fills Q up to index n*2+100, so main() rejects n with n*3+100 >= MAXN.
constexpr int MAXN = 100100;
// Operand length at or below which karatsuba() uses schoolbook multiplication.
constexpr size_t KARATSUBA_CUTOFF = 20;

int n, MOD;
// For k in 1..8 (row 0 is unused):
//   P[k][j] - left factor for index j, published by dnc() once dp[j] is final;
//   Q[k][s] - fixed right factor (shifted factorials, filled in solve());
//   R[k][i] - running sum of P[k][j] * Q[k][i + j] over already-final j < i.
// dp[i][0] / dp[i][1] are the two DP states per index.
// NOTE(review): their combinatorial meaning is not documented here.
ll P[9][MAXN], Q[9][MAXN], R[9][MAXN], dp[MAXN][2];
ll fact[MAXN], factINV[MAXN];  // i! and (i!)^{-1} modulo MOD

// x = (x + y) mod MOD; assumes both operands are already in [0, MOD).
inline void add(ll &x, const ll &y) {
    x += y;
    if (x >= MOD) x -= MOD;
}

// Product of all elements of a, reduced mod MOD after every step.
inline ll mul(const vector<ll> &a) {
    ll ret = 1;
    for (auto &x : a) ret = ret * x % MOD;
    return ret;
}

// x * y mod MOD; assumes x, y are small enough that x * y fits in ll.
inline ll mul(ll x, ll y) { return x * y % MOD; }

// Element-wise sum mod MOD; the shorter operand is treated as zero-padded.
vector<ll> operator+(const vector<ll> &A, const vector<ll> &B) {
    vector<ll> ret(max(A.size(), B.size()));
    for (int i = 0; i < (int)ret.size(); i++) {
        if (i < (int)A.size()) add(ret[i], A[i]);
        if (i < (int)B.size()) add(ret[i], B[i]);
    }
    return ret;
}

// Copy of the inclusive slice a[l..r].
vector<ll> cut(ll a[], int l, int r) {
    vector<ll> ret(r - l + 1);
    for (int i = 0; i < (int)ret.size(); i++) ret[i] = a[l + i];
    return ret;
}

// Copy of the inclusive slice a[l..r].
vector<ll> cut(const vector<ll> &a, int l, int r) {
    vector<ll> ret(r - l + 1);
    for (int i = 0; i < (int)ret.size(); i++) ret[i] = a[l + i];
    return ret;
}

// a^e mod MOD by recursive square-and-multiply.
ll pw(ll a, ll e) {
    if (!e) return 1;
    ll ret = pw(a, e / 2);
    if (e & 1) return ret * ret % MOD * a % MOD;
    return ret * ret % MOD;
}

// 1 * 3 * 5 * ... * (2n-1) mod MOD.
// Presumably the total count from which main() subtracts dp[n+1][1]
// (TODO confirm against the problem statement).
ll all() {
    ll ret = 1;
    for (int i = 1; i <= n; i++) ret = ret * (i * 2 - 1) % MOD;
    return ret;
}

// Fills fact[0..n*3+100] and factINV[0..n*3+100].
// The top inverse uses Fermat's little theorem, so MOD must be prime.
void init() {
    fact[0] = 1;
    for (int i = 1; i <= n * 3 + 100; i++) fact[i] = fact[i - 1] * i % MOD;
    factINV[n * 3 + 100] = pw(fact[n * 3 + 100], MOD - 2);
    for (int i = n * 3 + 99; i >= 0; i--) factINV[i] = factINV[i + 1] * (i + 1) % MOD;
}

// Schoolbook polynomial product mod MOD; result length |A| + |B| - 1.
vector<ll> naive(const vector<ll> &A, const vector<ll> &B) {
    vector<ll> ret(A.size() + B.size() - 1);
    for (int i = 0; i < (int)A.size(); i++) {
        for (int j = 0; j < (int)B.size(); j++) {
            add(ret[i + j], A[i] * B[j] % MOD);
        }
    }
    return ret;
}

// Karatsuba polynomial product mod MOD; result length |A| + |B| - 1.
// Written for equal-length operands (myConv() guarantees |A| == |B|).
vector<ll> karatsuba(const vector<ll> &A, const vector<ll> &B) {
    if (A.size() <= KARATSUBA_CUTOFF) return naive(A, B);
    int mid = (int)A.size() / 2;
    // Split each operand as low = X2 (first mid coeffs), high = X1 (the rest).
    auto A2 = cut(A, 0, mid - 1), A1 = cut(A, mid, (int)A.size() - 1);
    auto B2 = cut(B, 0, mid - 1), B1 = cut(B, mid, (int)B.size() - 1);
    auto C1 = karatsuba(A1, B1);            // high * high
    auto C2 = karatsuba(A2, B2);            // low * low
    auto C3 = karatsuba(A1 + A2, B1 + B2);  // (high + low) * (high + low)
    C2.resize(C1.size());  // zero-pad so the middle term can subtract index-wise
    assert(C1.size() == C2.size());
    assert(C1.size() == C3.size());
    vector<ll> ret(A.size() + B.size() - 1);
    for (int i = 0; i < (int)C1.size(); i++) add(ret[i + mid * 2], C1[i]);
    for (int i = 0; i < (int)C2.size(); i++) add(ret[i], C2[i]);
    // Middle term C3 - C1 - C2; +2*MOD keeps the value non-negative before %.
    for (int i = 0; i < (int)C3.size(); i++)
        add(ret[i + mid], (C3[i] - C1[i] - C2[i] + (ll)MOD * 2) % MOD);
    return ret;
}

// Cross-correlation via one multiplication. Requires |A| <= |B|.
// With ofs = |A| - 1, for t >= 0:
//     ret[ofs + t] = sum_x A[x] * B[x + t]   (mod MOD),
// where terms with x + t outside B count as zero.
pair<vector<ll>, int> myConv(vector<ll> A, vector<ll> B) {
    int ofs = (int)A.size() - 1;
    int sz = (int)max(A.size(), B.size());
    assert(A.size() <= B.size());
    reverse(A.begin(), A.end());
    A.resize(sz);  // karatsuba() expects equal-length operands
    return {karatsuba(A, B), ofs};
}

// Divide and conquer over DP indices [l, r].
// Each call first finalizes [l, m]. It then adds, for every k and every
// i in (m, r], the terms P[k][j] * Q[k][i + j] for j in [l, m], and
// finally finalizes (m, r]. So when a leaf l is reached, R[k][l] holds
// the sum of P[k][j] * Q[k][l + j] over all j < l.
void dnc(int l, int r) {
    if (l == r) {
        if (l == 1) {
            dp[1][0] = 0;
            dp[1][1] = 1;
        } else if (l == 2) {
            dp[2][0] = 1;
            dp[2][1] = 1;
        } else {
            // Combine the accumulated sums with polynomial-in-l coefficients.
            add(dp[l][0], R[1][l]);
            add(dp[l][0], mul(R[2][l], (MOD - l * 2 + 1)));   // * (1 - 2l)
            add(dp[l][0], mul(R[3][l], mul(l, l) - l + MOD)); // * (l^2 - l)
            add(dp[l][0], R[4][l]);
            add(dp[l][0], mul(R[5][l], l - 1));
            add(dp[l][1], R[6][l]);
            add(dp[l][1], mul(R[7][l], l));
            add(dp[l][1], R[8][l]);
        }
        // Publish this index's left factors for every later index.
        if (l >= 2) {
            P[1][l] = mul({dp[l][0], l, l, factINV[2 * l - 3]});
            P[2][l] = mul({dp[l][0], l, factINV[2 * l - 3]});
            P[3][l] = mul(dp[l][0], factINV[2 * l - 3]);
            P[6][l] = mul({dp[l][0], MOD - l, factINV[2 * l - 3]});
            P[7][l] = mul(dp[l][0], factINV[2 * l - 3]);
        }
        if (l >= 1) {
            P[4][l] = mul({dp[l][1], MOD - l, factINV[2 * l - 2]});
            P[5][l] = mul(dp[l][1], factINV[2 * l - 2]);
            P[8][l] = mul(dp[l][1], factINV[2 * l - 2]);
        }
        return;
    }
    int m = (l + r) >> 1;
    dnc(l, m);
    for (int k = 1; k <= 8; k++) {
        // A = P[k][l..m], B = Q[k][m+1+l .. m+r]; ret[ofs + (i-m-1)] is the
        // sum over j in [l, m] of P[k][j] * Q[k][i + j].
        auto [ret, ofs] = myConv(cut(P[k], l, m), cut(Q[k], m + 1 + l, m + r));
        for (int i = m + 1; i <= r; i++) add(R[k][i], ret[ofs + (i - m - 1)]);
    }
    dnc(m + 1, r);
}

// Fills the right-factor tables Q with shifted factorials and runs the
// divide and conquer over indices 1..n+1.
void solve() {
    for (int i = 0; i <= n * 2 + 100; i++) {
        if (i >= 5) Q[1][i] = Q[2][i] = Q[3][i] = fact[i - 5];
        if (i >= 4) Q[4][i] = Q[5][i] = Q[6][i] = Q[7][i] = fact[i - 4];
        if (i >= 3) Q[8][i] = fact[i - 3];
    }
    dnc(1, n + 1);
}

int main() {
    if (scanf("%d %d", &n, &MOD) != 2) return 1;
    // init() writes fact[n*3+100]; refuse n that would run off the tables.
    // The bound is written as a division to avoid int overflow on huge n.
    if (n < 0 || n > (MAXN - 101) / 3) return 1;
    init();
    solve();
    printf("%lld\n", (all() + MOD - dp[n + 1][1]) % MOD);
    return 0;
}

Compilation message (stderr)

festival2.cpp: In function 'int main()':
festival2.cpp:195:7: warning: ignoring return value of 'int scanf(const char*, ...)' declared with attribute 'warn_unused_result' [-Wunused-result]
  195 |  scanf("%d %d", &n, &MOD);
      |  ~~~~~^~~~~~~~~~~~~~~~~~~
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...