Submission #637827

# Submission time Handle Problem Language Result Execution time Memory
637827 2022-09-03T11:49:47 Z TAMREF 흑백 이미지 찾기 (kriii3_G) C++17
0 / 101
3557 ms 113044 KB
#include <bits/stdc++.h>
#define va first
#define vb second
#define lb lower_bound
#define ub upper_bound
#define bs binary_search
#define pp push_back
#define ep emplace_back
#define all(v) (v).begin(),(v).end()
#define szz(v) ((int)(v).size())
#define bi_pc __builtin_popcount
#define bi_pcll __builtin_popcountll
#define bi_tz __builtin_ctz
#define bi_tzll __builtin_ctzll
#define fio ios_base::sync_with_stdio(0);cin.tie(0);
// Debug logging helper.
// With TAMREF defined it forwards its arguments to fprintf(stderr, ...).
// Otherwise it expands to a no-op expression and its arguments are never
// evaluated. ((void)0) is used instead of a bare literal so that writing
// `debug(...);` as a statement does not trigger -Wunused-value
// ("statement has no effect").
#ifdef TAMREF
#define debug(...) fprintf(stderr, __VA_ARGS__)
#else
#define debug(...) ((void)0)
#endif
using namespace std;
using ll = long long; using lf = long double; 
using pii = pair<int,int>; using ppi = pair<int,pii>;
using pll = pair<ll,ll>; using pff = pair<lf,lf>;
using ti = tuple<int,int,int>;
using base = complex<double>;
const lf PI = 3.14159265358979323846264338L;
// Raises `u` to `v` when `v` is the larger of the two, then returns the
// (possibly updated) value of `u` by value.
template <typename T>
inline T umax(T& u, T v){
  if(u < v) u = v;
  return u;
}
// Lowers `u` to `v` when `v` is the smaller of the two, then returns the
// (possibly updated) value of `u` by value.
template <typename T>
inline T umin(T& u, T v){
  if(v < u) u = v;
  return u;
}
// 64-bit Mersenne Twister seeded with the steady clock's current tick count,
// so the generated sequence differs between runs.
// NOTE(review): not referenced anywhere else in the visible code.
mt19937_64 rng(chrono::steady_clock::now().time_since_epoch().count());
// Type trait that marks the standard containers the generic stream-insertion
// overload is allowed to print. The primary template reports false; each
// listed container template is opted in by a partial specialization.
template<class>
struct is_container : std::bool_constant<false> {};

// Sequence container.
template<class... Args>
struct is_container<std::vector<Args...>> : std::bool_constant<true> {};

// Key/value containers.
template<class... Args>
struct is_container<std::map<Args...>> : std::bool_constant<true> {};
template<class... Args>
struct is_container<std::unordered_map<Args...>> : std::bool_constant<true> {};

// Unique-key sets.
template<class... Args>
struct is_container<std::set<Args...>> : std::bool_constant<true> {};
template<class... Args>
struct is_container<std::unordered_set<Args...>> : std::bool_constant<true> {};

// Multisets.
template<class... Args>
struct is_container<std::multiset<Args...>> : std::bool_constant<true> {};
template<class... Args>
struct is_container<std::unordered_multiset<Args...>> : std::bool_constant<true> {};
template<class T, class = typename enable_if<is_container<T>::value>::type>
ostream& operator<<(ostream &o, T x) {
  #ifndef TAMREF
  return o;
  #endif
  int f = 1;
  o << "{";
  for (auto y : x) {
    o << (f ? "" : ", ") << y;
    f = 0;
  }
  return o << "}\n";
}
// Stream insertion for std::pair.
//
// With TAMREF defined, writes "(first, second)". Without TAMREF it writes
// nothing. Returns the stream `o` in both cases.
//
// The pair is taken by const reference so that printing does not copy it.
template<class T, class U>
std::ostream& operator<<(std::ostream &o, const std::pair<T, U> &x) {
#ifdef TAMREF
  return o << "(" << x.first << ", " << x.second << ")";
#else
  (void)x;  // release build: intentionally prints nothing
  return o;
#endif
}

/**
 * Number-theoretic transform (NTT) helpers working modulo `mod`.
 *
 * Requirements implied by the implementation:
 *  - 3 is used as the generator when deriving roots of unity (the original
 *    code notes this choice for mod = 998244353);
 *  - every transform length must be a power of two (the bit-reversal and
 *    butterfly loops rely on it) and must divide mod - 1;
 *  - values are expected to lie in [0, mod).
 */
template<long long mod>
struct NTT{
  using ll = long long;

  /// Returns x^n modulo `mod` by binary exponentiation. `n` must be >= 0.
  /// The base is reduced into [0, mod) first so the squaring cannot overflow
  /// for large or negative `x`.
  inline ll pw(ll x, ll n) {
    ll ret = 1, pv = x % mod;
    if(pv < 0) pv += mod;
    for(; n; n >>= 1){
      if(n & 1) ret = ret * pv % mod;
      pv = pv * pv % mod;
    }
    return ret;
  }

  /// Modular inverse of x, computed as x^(mod - 2) (valid for prime `mod`).
  inline ll multInv(ll x){
    return pw(x, mod - 2);
  }

  /// Builds the table ang^0 .. ang^(n/2 - 1), where ang = 3^((mod-1)/n), or
  /// its inverse when `inv` is set. `n` must be positive.
  std::vector<ll> root_table(int n, bool inv){
    const int w = 3; // generator; the original code notes this for mod = 998244353
    ll ang = pw(w, (mod - 1) / n);
    if(inv) ang = multInv(ang);
    std::vector<ll> roots(n / 2);
    for(int i = 0; i < n / 2; i++){
      roots[i] = (i ? roots[i-1] * ang % mod : 1);
    }
    return roots;
  }

  /// In-place transform of `a` (length must be a power of two).
  /// With `inv` set, applies the inverse transform including the 1/n scaling.
  /// An empty vector is left untouched.
  void ntt(std::vector<ll>& a, bool inv){
    const int n = static_cast<int>(a.size());
    if(n == 0) return; // avoids (mod - 1) / 0 below

    // Bit-reversal permutation.
    for(int i = 1, j = 0; i < n; i++){
      int b = (n >> 1);
      while(j >= b) {
        j -= b; b >>= 1;
      }
      j += b;
      if(i < j) std::swap(a[i], a[j]);
    }

    const std::vector<ll> roots = root_table(n, inv);
    for(int len = 2; len <= n; len <<= 1){
      const int half = len / 2, step = n / len;
      for(int start = 0; start < n; start += len){
        for(int k = 0; k < half; k++){
          const ll u = a[start + k];
          const ll v = a[start + k + half] * roots[step * k] % mod;
          a[start + k] = (u + v) % mod;
          a[start + k + half] = (u + mod - v) % mod;
        }
      }
    }

    if(inv){
      const ll ninv = multInv(n);
      for(int i = 0; i < n; i++) a[i] = a[i] * ninv % mod;
    }
  }

  /// In-place transform along the first index of a rectangular grid: for every
  /// column h, the sequence a[0][h], a[1][h], ... is transformed. The number of
  /// rows must be a power of two and all rows are expected to have the length
  /// of a[0]. An empty grid is left untouched.
  void ntt_row(std::vector<std::vector<ll>> &a, bool inv) {
    const int n = static_cast<int>(a.size());
    if(n == 0) return;
    const int m = static_cast<int>(a[0].size());

    // Bit-reversal permutation of whole rows (O(1) swap per pair).
    for(int i = 1, j = 0; i < n; i++){
      int b = (n >> 1);
      while(j >= b) {
        j -= b; b >>= 1;
      }
      j += b;
      if(i < j) a[i].swap(a[j]);
    }

    const std::vector<ll> roots = root_table(n, inv);
    for(int len = 2; len <= n; len <<= 1){
      const int half = len / 2, step = n / len;
      for(int start = 0; start < n; start += len){
        for(int k = 0; k < half; k++){
          const ll w = roots[step * k];
          std::vector<ll> &lo = a[start + k];
          std::vector<ll> &hi = a[start + k + half];
          for(int h = 0; h < m; h++) {
            const ll u = lo[h], v = hi[h] * w % mod;
            lo[h] = (u + v) % mod;
            hi[h] = (u + mod - v) % mod;
          }
        }
      }
    }

    if(inv){
      // The transform length here is the number of rows n, so the inverse must
      // be scaled by 1/n. Scaling by 1/m (the column count) is only correct
      // for square grids.
      const ll ninv = multInv(n);
      for(int i = 0; i < n; i++) for(int h = 0; h < m; h++) a[i][h] = a[i][h] * ninv % mod;
    }
  }

  /// Two-dimensional transform: every row is transformed with ntt(), then the
  /// columns are transformed with ntt_row().
  void ntt2d(std::vector<std::vector<ll>> &a, bool inv) {
    const int n = static_cast<int>(a.size());
    for(int i = 0; i < n; i++) ntt(a[i], inv);
    ntt_row(a, inv);
  }

  /// Two-dimensional convolution of grids `a` and `b` modulo `mod`.
  /// Both grids are zero-padded to power-of-two dimensions at least as large
  /// as the sum of their sizes in each direction; the returned grid has those
  /// padded dimensions. Returns an empty grid if either input has no rows.
  std::vector<std::vector<ll>> mult2d(std::vector<std::vector<ll>> a, std::vector<std::vector<ll>> b) {
    if(a.empty() || b.empty()) return {};
    const int rows_needed = static_cast<int>(a.size() + b.size());
    const int cols_needed = static_cast<int>(a[0].size() + b[0].size());
    int n = 2; while(n < rows_needed) n <<= 1;
    int m = 2; while(m < cols_needed) m <<= 1;
    a.resize(n); b.resize(n);
    for(auto &row : a) row.resize(m);
    for(auto &row : b) row.resize(m);
    ntt2d(a, false); ntt2d(b, false);
    for(int i = 0; i < n; i++) for(int j = 0; j < m; j++) a[i][j] = a[i][j] * b[i][j] % mod;
    ntt2d(a, true);
    return a;
  }

  /// One-dimensional convolution of `a` and `b` modulo `mod`. The result has
  /// the padded power-of-two length (>= a.size() + b.size(), at least 2).
  std::vector<ll> mult(const std::vector<ll> &a, const std::vector<ll> &b){
    std::vector<ll> fa(a.begin(), a.end()), fb(b.begin(), b.end());
    const int needed = static_cast<int>(a.size() + b.size());
    int n = 2; while(n < needed) n <<= 1;
    fa.resize(n); fb.resize(n);
    ntt(fa, false); ntt(fb, false);
    for(int i = 0; i < n; i++) fa[i] = fa[i] * fb[i] % mod;
    ntt(fa, true);
    return fa;
  }
};

// The two prime moduli the convolutions are computed under;
// mod2 = 7 * 2^26 + 1 = 469762049.
constexpr ll mod1 = 998244353, mod2 = 7 << 26 | 1;
// One transform instance per modulus.
NTT<mod1> n1;
NTT<mod2> n2; 
// Input matrices filled by main(): `a` is n x m, `b` is r x c and is stored
// reversed along both axes.
vector<vector<ll>> a, b;
// Dimensions of `a` (n rows, m columns) and of `b` (r rows, c columns).
int n, m, r, c;
// Moduli under which main() checks its per-window identity; a window is only
// counted when the identity holds under all four.
const ll tmods[4] = {1000000007, 1000000009, 998244353, 7 << 26 | 1};
// Shorthand for the signed 128-bit integer type.
// NOTE(review): not referenced anywhere else in the visible code.
using L = __int128_t;

// Combines two residue grids into one grid of values modulo mod1 * mod2 using
// the Chinese Remainder Theorem.
//
// a: residues modulo mod1 on entry; overwritten with the combined values,
//    each in [0, mod1 * mod2).
// b: residues modulo mod2 (read only; expected to have the same shape as a).
//
// Every cell present in both grids is processed. The column bound comes from
// the row width, not from the number of rows, so non-square grids are merged
// completely and without out-of-range access.
void merge_by_crt(vector<vector<ll>> &a, vector<vector<ll>> &b) {
  const ll pmod = mod1 * mod2;
  // CRT basis coefficients, kept in 128 bits:
  //   coef_a == 1 (mod mod1) and == 0 (mod mod2)
  //   coef_b == 0 (mod mod1) and == 1 (mod mod2)
  const __int128_t coef_a = __int128_t(mod2) * n1.multInv(mod2);
  const __int128_t coef_b = __int128_t(mod1) * n2.multInv(mod1);

  const size_t rows = min(a.size(), b.size());
  for(size_t i = 0; i < rows; i++) {
    const size_t cols = min(a[i].size(), b[i].size());
    for(size_t j = 0; j < cols; j++) {
      // The products need well over 64 bits, hence the 128-bit arithmetic.
      a[i][j] = ll((a[i][j] * coef_a + b[i][j] * coef_b) % pmod);
    }
  }
}

int main(){
  fio;
  cin >> n >> m;
  a = vector<vector<ll>>(n, vector<ll>(m));
  for(int i = 0; i < n; i++) for(int j = 0; j < m; j++) cin >> a[i][j];
  cin >> r >> c;
  b = vector<vector<ll>>(r, vector<ll>(c));
  for(int i = 0; i < r; i++) for(int j = 0; j < c; j++) cin >> b[r-1-i][c-1-j];
  auto ab = n1.mult2d(a, b);
  auto ab2 = n2.mult2d(a, b);
  merge_by_crt(ab, ab2);

  #ifdef TAMREF
  cerr << ab << '\n';
  #endif
  __int128_t asum = 0, a2sum = 0, bsum = 0, b2sum = 0;

  for(int i = 0; i < r; i++) {
    for(int j = 0; j < c; j++) {
      bsum += b[i][j];
      b2sum += b[i][j] * b[i][j];
    }
  }

  ll ans = 0;


  auto compare = [&](int i, int j) {
    for(const ll mod : tmods) {
      auto pw = [&](ll x, ll n) -> ll {
        ll r = 1, p = x;
        for(;n;n>>=1) {
          if(n&1) r = r * p % mod;
          p = p * p % mod;
        }
        return r;
      };
      auto multInv = [&](ll x) -> ll {
        return pw(x, mod - 2);
      };
      auto normalize = [&](ll x) -> ll {
        x += mod;
        return x >= mod ? x - mod : x;
      };
      ll _a = asum % mod, _b = bsum % mod, _a2 = a2sum % mod, _b2 = b2sum % mod;
      ll rcinv = multInv(r * c), rcinv2 = rcinv * rcinv % mod;
      ll bvar = normalize((_b2 * rcinv - _b * _b % mod * rcinv2) % mod);
      ll cov = normalize((ab[i + r - 1][j + c - 1] % mod * rcinv - _a * _b % mod * rcinv2) % mod);
      ll avar = normalize((_a2 * rcinv - _a * _a % mod * rcinv2) % mod);
      ll diff = normalize((cov * cov - avar * bvar) % mod);
      if(diff) {
        debug("i = %d, j = %d, mod = %lld, diff = %lld\n", i, j, mod, diff);
      }
      if(diff) return;
    }
    ++ans;
    debug("ans = %lld\n", ans);
  };

  for(int k = 0; k + r - 1 < n; k++) {
    asum = a2sum = 0;
    for(int i = 0; i < r; i++) {
      for(int j = 0; j < c; j++) {
        asum += a[k+i][j];
        a2sum += a[k+i][j] * a[k+i][j];
      }
    }

    compare(k, 0);
    for(int j = 1; j + c - 1 < m; j++) {
      for(int i = 0; i < r; i++) {
        asum += a[k+i][j+c-1] - a[k+i][j-1];
        a2sum += a[k+i][j+c-1] * a[k+i][j+c-1] - a[k+i][j-1] * a[k+i][j-1];
      }
      compare(k, j);
    }
  }

  cout << ans << '\n';
}

Compilation message

G.cpp: In lambda function:
G.cpp:19:20: warning: statement has no effect [-Wunused-value]
   19 | #define debug(...) 42
      |                    ^~
G.cpp:240:9: note: in expansion of macro 'debug'
  240 |         debug("i = %d, j = %d, mod = %lld, diff = %lld\n", i, j, mod, diff);
      |         ^~~~~
G.cpp:19:20: warning: statement has no effect [-Wunused-value]
   19 | #define debug(...) 42
      |                    ^~
G.cpp:245:5: note: in expansion of macro 'debug'
  245 |     debug("ans = %lld\n", ans);
      |     ^~~~~
# Verdict Execution time Memory Grader output
1 Correct 17 ms 980 KB Output is correct
2 Correct 11 ms 980 KB Output is correct
3 Incorrect 11 ms 996 KB Output isn't correct
4 Halted 0 ms 0 KB -
# Verdict Execution time Memory Grader output
1 Correct 2083 ms 112276 KB Output is correct
2 Correct 2159 ms 112316 KB Output is correct
3 Correct 2234 ms 112428 KB Output is correct
4 Correct 2847 ms 112176 KB Output is correct
5 Correct 2717 ms 112132 KB Output is correct
6 Correct 2176 ms 109068 KB Output is correct
7 Correct 2302 ms 109028 KB Output is correct
8 Correct 2316 ms 109016 KB Output is correct
9 Correct 2372 ms 112712 KB Output is correct
10 Correct 2331 ms 112748 KB Output is correct
11 Correct 3557 ms 113004 KB Output is correct
12 Correct 3536 ms 111052 KB Output is correct
13 Correct 3551 ms 113044 KB Output is correct
14 Incorrect 2633 ms 112992 KB Output isn't correct
15 Halted 0 ms 0 KB -