This submission was migrated from the previous version of oj.uz, which used a different machine for grading. It may get a different result if resubmitted.
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <memory.h>
#include <math.h>
#include <assert.h>
#include <time.h>
#include <limits.h>
#include <algorithm>
#include <bitset>
#include <cmath>
#include <functional>
#include <iostream>
#include <map>
#include <queue>
#include <set>
#include <stack>
#include <string>
#include <unordered_set>
#include <vector>
using namespace std;
typedef long long ll;
typedef pair<int, int> pii;
typedef pair<ll, ll> pll;
#define Fi first
#define Se second
#define pb push_back
#define szz(x) (int)x.size()
#define rep(i,n) for(int i=0;i<n;i++)
#define all(x) x.begin(),x.end()
typedef tuple<int, int, int> t3;
#include <complex>
typedef long double ldouble;
namespace FFT{
// Complex-FFT polynomial multiplication, adapted from blog.myungwoo.kr/54.
typedef std::complex<long double> base;
typedef long long ll;
#define sz(x) ((int)(x).size())
const long double C_PI = std::acos(-1.0L);

// In-place iterative radix-2 FFT of `a`.
// The size of `a` must be a power of two (not checked).
// invert == false: forward transform (twiddle angle +2*pi/len).
// invert == true : inverse transform, including the final division by n.
void fft(std::vector<base> &a, bool invert){
    const int n = sz(a);
    // Bit-reversal permutation; each pair is swapped once, when i is the larger index.
    for (int i = 0, j = 0; i < n; ++i) {
        if (i > j) std::swap(a[i], a[j]);
        for (int k = n >> 1; (j ^= k) < k; k >>= 1);
    }
    for (int len = 2; len <= n; len <<= 1) {
        const long double ang = 2 * C_PI / len * (invert ? -1 : 1);
        const base wlen(std::cos(ang), std::sin(ang));
        for (int i = 0; i < n; i += len) {
            base w(1);
            for (int j = 0; j < len / 2; j++) {
                // Every 32nd step, recompute the twiddle directly instead of by repeated
                // multiplication, to limit accumulated rounding error. If the error is still
                // too large, recompute more often -- but cos/sin are comparatively expensive.
                if ((j & 31) == 31) w = base(std::cos(ang * j), std::sin(ang * j));
                const base u = a[i + j], v = a[i + j + len / 2] * w;
                a[i + j] = u + v;
                a[i + j + len / 2] = u - v;
                w *= wlen;
            }
        }
    }
    if (invert) {
        for (int i = 0; i < n; i++) a[i] /= n;
    }
}

// Multiplies the polynomials a and b. Each product coefficient is rounded to the nearest
// integer and reduced modulo MOD into [0, MOD); MOD > 0 is assumed.
// res is resized to the transform length (at least sz(a) + sz(b) - 1); the extra tail is 0.
// Only exact while the unreduced coefficients stay within long double precision;
// multiply_big handles larger magnitudes.
void multiply(const std::vector<int> &a, const std::vector<int> &b, std::vector<int> &res, const int MOD){
    std::vector<base> fa(a.begin(), a.end()), fb(b.begin(), b.end());
    int n = 1;
    while (n < std::max(sz(a), sz(b))) n <<= 1;
    n <<= 1;  // room for the full linear (non-cyclic) product
    fa.resize(n); fb.resize(n);
    fft(fa, false); fft(fb, false);
    for (int i = 0; i < n; i++) fa[i] *= fb[i];
    fft(fa, true);
    res.resize(n);
    for (int i = 0; i < n; i++) {
        ll v = std::llround(fa[i].real()) % MOD;
        if (v < 0) v += MOD;  // C++ % keeps the dividend's sign; normalize to [0, MOD)
        res[i] = static_cast<int>(v);
    }
}

// Helper for multiply_big. Given the packed spectra at index k (ak, bk) and at its mirror
// index (n-k) mod n (am, bm), it recovers the spectra of the low and high 10-bit limbs.
// It then writes the packed partial products for index k:
//   out_a = hi_a*hi_b + i * hi_a*lo_b,   out_b = lo_a*hi_b + i * lo_a*lo_b.
inline void unpack_products(base ak, base am, base bk, base bm, base &out_a, base &out_b){
    const base a_lo = (ak + std::conj(am)) * base(0.5, 0);
    const base a_hi = (ak - std::conj(am)) * base(0, -0.5);
    const base b_lo = (bk + std::conj(bm)) * base(0.5, 0);
    const base b_hi = (bk - std::conj(bm)) * base(0, -0.5);
    out_a = a_hi * b_hi + a_hi * b_lo * base(0, 1);
    out_b = a_lo * b_lo * base(0, 1) + a_lo * b_hi;
}

// Integer product of a and b for inputs where plain multiply loses precision.
// Each value x is split into 10-bit limbs lo = x & 1023 and hi = x >> 10, packed as lo + i*hi.
// Two complex FFTs therefore carry all four real sequences. The result is recombined as
//   lo_a*lo_b + (hi_a*lo_b + lo_a*hi_b) * 2^10 + hi_a*hi_b * 2^20.
// The result coefficients must fit in long long. Negative inputs are supported.
// res is resized to the transform length.
void multiply_big(const std::vector<int> &a, const std::vector<int> &b, std::vector<ll> &res){
    int n = 1;
    while (n < std::max(sz(a), sz(b))) n <<= 1;
    n <<= 1;  // room for the full linear (non-cyclic) product
    const int L_BLOCK = 10;
    const int MASK = (1 << L_BLOCK) - 1;
    std::vector<base> A(n), B(n);
    for (int i = 0; i < n; i++) A[i] = (i < sz(a) ? base(a[i] & MASK, a[i] >> L_BLOCK) : base(0));
    for (int i = 0; i < n; i++) B[i] = (i < sz(b) ? base(b[i] & MASK, b[i] >> L_BLOCK) : base(0));
    fft(A, false); fft(B, false);
    // Unpack and multiply in place, one mirror pair (i, (n-i) mod n) at a time. This avoids
    // four extra n-sized buffers. Both outputs are computed from copies of the original
    // values. When i == j (i = 0 or n/2) the second call rewrites the same values.
    for (int i = 0; i <= n / 2; i++) {
        const int j = (n - i) & (n - 1);
        const base ai = A[i], aj = A[j], bi = B[i], bj = B[j];
        unpack_products(ai, aj, bi, bj, A[i], B[i]);
        unpack_products(aj, ai, bj, bi, A[j], B[j]);
    }
    fft(A, true); fft(B, true);
    res.resize(n);
    for (int i = 0; i < n; i++) {
        // llround rounds to nearest for negative values too; (ll)(x + 0.5) did not.
        const ll hi_hi = std::llround(A[i].real());
        const ll hi_lo = std::llround(A[i].imag());
        const ll lo_hi = std::llround(B[i].real());
        const ll lo_lo = std::llround(B[i].imag());
        // Multiply rather than shift: left-shifting a negative value is UB before C++20.
        res[i] = lo_lo + (hi_lo + lo_hi) * (1LL << L_BLOCK) + hi_hi * (1LL << (2 * L_BLOCK));
    }
}
}
// Dimensions: matrix A is N x M, matrix B is R x C (the window size slid over A).
int N, M, R, C;
// Input matrices, stored 1-indexed; row/column 0 stays 0 (globals are zero-initialized).
int A[1010][1010], B[1010][1010];
// 2-D prefix sums of A and of A squared: Sx[i][j] sums A over rows 1..i and columns 1..j.
ll Sx[1010][1010], Sxx[1010][1010];
// va: A flattened row-major; vb: B reversed and padded to width M (filled in main).
vector <int> va, vb;
// res: convolution of va and vb, giving the dot product of B with each window of A.
vector <ll> res;
// Sum over the inclusive rectangle rows x1..x2, columns y1..y2 of the matrix whose 2-D
// prefix-sum table is T (built like Sx / Sxx: T[x][y] = sum over rows 1..x, cols 1..y).
// Requires x1 >= 1 and y1 >= 1 so that row x1-1 and column y1-1 are valid indices.
ll get_s(ll T[1010][1010], int x1, int y1, int x2, int y2) {
    const int top = x1 - 1;
    const int left = y1 - 1;
    ll total = T[x2][y2];
    total -= T[x2][left];
    total -= T[top][y2];
    total += T[top][left];
    return total;
}
// GCC/Clang 128-bit integer, so products of two ~1e18 values in main do not overflow.
#define i128 __int128
// Reads an N x M matrix A and an R x C matrix B (both 1-indexed). Counts the R x C windows
// of A for which the modular test below finds a zero value, and prints that count.
int main() {
scanf("%d%d", &N, &M);
// NOTE(review): scanf return values are ignored; malformed input is not detected.
for(int i=1;i<=N;i++) for(int j=1;j<=M;j++) scanf("%d", A[i] + j);
// 2-D prefix sums of A and A^2, so any window's sums are O(1) via get_s.
for(int i=1;i<=N;i++) for(int j=1;j<=M;j++) {
Sx[i][j] = Sx[i-1][j] + Sx[i][j-1] - Sx[i-1][j-1] + A[i][j];
Sxx[i][j] = Sxx[i-1][j] + Sxx[i][j-1] - Sxx[i-1][j-1] + (ll)A[i][j] * A[i][j];
}
scanf("%d%d", &R, &C);
for(int i=1;i<=R;i++) for(int j=1;j<=C;j++) scanf("%d", B[i] + j);
// Sum of B and sum of B^2.
ll sy = 0, syy = 0;
for(int i=1;i<=R;i++) for(int j=1;j<=C;j++) {
sy += B[i][j];
syy += (ll) B[i][j] * B[i][j];
}
// Cross-correlation of A with B through a single convolution.
// va is A row-major. vb is B reversed (last row first, last column first), with every row
// padded by M - C zeros to width M. The convolution entry (i-1)*M + (j-1) then equals
// sum over u, v of A[i-R+u][j-C+v] * B[u][v]. That is the dot product of B with the
// R x C window whose bottom-right corner is (i, j); no row wrap-around terms can appear.
for(int i=1;i<=N;i++) for(int j=1;j<=M;j++) va.pb(A[i][j]);
for(int i=R;i;i--) {
for(int j=C;j;j--) vb.pb(B[i][j]);
rep(j, M - C) vb.pb(0);
}
FFT::multiply_big(va, vb, res);
int ans = 0;
// Three moduli of about 1e18 (presumably primes -- not verified here).
ll prs[3] = {1000000000000000003ll, 1000000000000000009ll, 1000000000000000031ll};
// (i, j) is the bottom-right corner of the current window.
for(int i=R;i<=N;i++) for(int j=C;j<=M;j++) {
ll sxy = res[(i-1)*M+(j-1)];
ll sx = get_s(Sx, i-R+1, j-C+1, i, j);
ll sxx = get_s(Sxx, i-R+1, j-C+1, i, j);
ll n = R * C;
int ok = 1;
// val = Syy*(n*Sxx - Sx^2) - n*Sxy^2 - Sxx*Sy^2 + 2*Sxy*Sx*Sy, which equals
// [(n*Sxx - Sx^2)*(n*Syy - Sy^2) - (n*Sxy - Sx*Sy)^2] / n.
// By Cauchy-Schwarz it is 0 exactly when the mean-centred window and the mean-centred B are
// linearly dependent. This includes the case where either one is constant.
// It is only checked modulo each of the three moduli, so a nonzero value is counted as
// zero only if all three divide it.
rep(u, 3) {
ll mod = prs[u];
i128 val = (i128)syy * (((i128)n * sxx - (i128)sx * sx) % mod) % mod;
val = (val - (i128) sxy * sxy % mod * n) % mod; if(val < 0) val += mod;
val = (val - (i128) sxx * sy % mod * sy) % mod; if(val < 0) val += mod;
// After this final % the value lies in (-mod, mod), so it is 0 iff the expression is 0 mod `mod`.
val = (val + (i128) 2 * sxy * sx % mod * sy) % mod;
if(val != 0) ok = 0;
}
ans += ok;
}
printf("%d\n", ans);
return 0;
}
Compilation message (stderr)
G.cpp: In function 'void FFT::multiply(const std::vector<int>&, const std::vector<int>&, std::vector<int>&, int)':
G.cpp:71:3: warning: this 'while' clause does not guard... [-Wmisleading-indentation]
while (n < max(sz(a),sz(b))) n <<= 1; n <<= 1;
^~~~~
G.cpp:71:41: note: ...this statement, but the latter is misleadingly indented as if it were guarded by the 'while'
while (n < max(sz(a),sz(b))) n <<= 1; n <<= 1;
^
G.cpp: In function 'void FFT::multiply_big(const std::vector<int>&, const std::vector<int>&, std::vector<long long int>&)':
G.cpp:84:3: warning: this 'while' clause does not guard... [-Wmisleading-indentation]
while (n < max(sz(a),sz(b))) n <<= 1; n <<= 1;
^~~~~
G.cpp:84:41: note: ...this statement, but the latter is misleadingly indented as if it were guarded by the 'while'
while (n < max(sz(a),sz(b))) n <<= 1; n <<= 1;
^
G.cpp: In function 'int main()':
G.cpp:130:7: warning: ignoring return value of 'int scanf(const char*, ...)', declared with attribute warn_unused_result [-Wunused-result]
scanf("%d%d", &N, &M);
~~~~~^~~~~~~~~~~~~~~~
G.cpp:131:51: warning: ignoring return value of 'int scanf(const char*, ...)', declared with attribute warn_unused_result [-Wunused-result]
for(int i=1;i<=N;i++) for(int j=1;j<=M;j++) scanf("%d", A[i] + j);
~~~~~^~~~~~~~~~~~~~~~
G.cpp:136:7: warning: ignoring return value of 'int scanf(const char*, ...)', declared with attribute warn_unused_result [-Wunused-result]
scanf("%d%d", &R, &C);
~~~~~^~~~~~~~~~~~~~~~
G.cpp:137:51: warning: ignoring return value of 'int scanf(const char*, ...)', declared with attribute warn_unused_result [-Wunused-result]
for(int i=1;i<=R;i++) for(int j=1;j<=C;j++) scanf("%d", B[i] + j);
~~~~~^~~~~~~~~~~~~~~~
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |