답안 #637865

# 제출 시각 아이디 문제 언어 결과 실행 시간 메모리
637865 2022-09-03T12:36:02 Z TAMREF 흑백 이미지 찾기 (kriii3_G) C++17
101 / 101
2327 ms 126552 KB
#include <bits/stdc++.h>
#define va first
#define vb second
#define lb lower_bound
#define ub upper_bound
#define bs binary_search
#define pp push_back
#define ep emplace_back
#define all(v) (v).begin(),(v).end()
#define szz(v) ((int)(v).size())
#define bi_pc __builtin_popcount
#define bi_pcll __builtin_popcountll
#define bi_tz __builtin_ctz
#define bi_tzll __builtin_ctzll
#define fio ios_base::sync_with_stdio(0);cin.tie(0);
// debug(...) forwards to fprintf(stderr, ...) only when TAMREF is defined.
// Otherwise it expands to ((void)0): a void no-op expression that is valid
// wherever a statement or expression is expected, and (unlike a bare integer
// literal) does not trigger -Wunused-value "statement has no effect".
#ifdef TAMREF
#define debug(...) fprintf(stderr, __VA_ARGS__)
#else
#define debug(...) ((void)0)
#endif
using namespace std;
using ll = long long; using lf = long double;
using pii = pair<int,int>; using ppi = pair<int,pii>;
using pll = pair<ll,ll>; using pff = pair<lf,lf>;
using ti = tuple<int,int,int>;
using base = complex<double>;
const lf PI = 3.14159265358979323846264338L;
// umax: raises u to v when v is larger; returns the resulting u (by value).
template <typename T>
inline T umax(T& u, T v){
  if(u < v) u = v;
  return u;
}
// umin: lowers u to v when v is smaller; returns the resulting u (by value).
template <typename T>
inline T umin(T& u, T v){
  if(v < u) u = v;
  return u;
}
// Global 64-bit Mersenne Twister, seeded from the steady clock's current tick count.
mt19937_64 rng(static_cast<mt19937_64::result_type>(chrono::steady_clock::now().time_since_epoch().count()));
template<class> struct is_container : false_type {};
template<class... Ts> struct is_container<vector<Ts...>> : true_type {};
template<class... Ts> struct is_container<map<Ts...>> : true_type {};
template<class... Ts> struct is_container<unordered_map<Ts...>> : true_type {};
template<class... Ts> struct is_container<set<Ts...>> : true_type {};
template<class... Ts> struct is_container<unordered_set<Ts...>> : true_type {};
template<class... Ts> struct is_container<multiset<Ts...>> : true_type {};
template<class... Ts> struct is_container<unordered_multiset<Ts...>> : true_type {};
template<class T, class = typename enable_if<is_container<T>::value>::type>
ostream& operator<<(ostream &o, T x) {
  #ifndef TAMREF
  return o;
  #endif
  int f = 1;
  o << "{";
  for (auto y : x) {
    o << (f ? "" : ", ") << y;
    f = 0;
  }
  return o << "}\n";
}
// Debug printer for std::pair as "(first, second)". Writes nothing unless
// TAMREF is defined; the stream is returned either way.
template<class T, class U>
ostream& operator<<(ostream &o, pair<T, U> x) {
#ifdef TAMREF
  o << "(" << x.first << ", " << x.second << ")";
#else
  (void)x;
#endif
  return o;
}

// Number-theoretic transform over Z_mod, assuming mod is prime and `root` is a
// primitive root of mod. The default root 3 is a primitive root of both moduli
// this file instantiates: 998244353 = 119 * 2^23 + 1 and 469762049 = 7 * 2^26 + 1.
// Every transform length must be a power of two dividing mod - 1.
template<long long mod, long long root = 3>
struct NTT{
  using ll = long long;

  // x^n mod `mod` by binary exponentiation (n >= 0). x is first reduced into
  // [0, mod), so arguments >= mod (or negative) are handled without overflow.
  ll pw(ll x, ll n) {
    ll ret = 1, pv = (x % mod + mod) % mod;
    for(; n; n >>= 1){
      if(n & 1) ret = ret * pv % mod;
      pv = pv * pv % mod;
    }
    return ret;
  }

  // Modular inverse via Fermat's little theorem (requires mod prime and
  // x not divisible by mod).
  ll multInv(ll x){
    return pw(x, mod - 2);
  }

  // In-place 1D transform of `a`: forward when inv == false, inverse when
  // inv == true. a.size() must be a power of two; the inverse includes the
  // 1/n normalisation.
  void ntt(vector<ll>& a, bool inv){
    const int n = (int)a.size();
    if(n < 2) return;   // a transform of length 0 or 1 is the identity
    bit_reverse(a);
    const vector<ll> roots = twiddles(n, inv);
    for(int len = 2; len <= n; len <<= 1){
      const int step = n / len, half = len / 2;
      for(int s = 0; s < n; s += len){
        for(int k = 0; k < half; k++){
          const ll u = a[s + k], v = a[s + k + half] * roots[step * k] % mod;
          a[s + k] = (u + v) % mod;
          a[s + k + half] = (u + mod - v) % mod;
        }
      }
    }
    if(inv) scale(a, multInv(n));
  }

  // In-place transform of length n = a.size() applied down every column of
  // the n x m matrix `a` (each butterfly combines two whole rows).
  // n must be a power of two and all rows must have the width of a[0].
  void ntt_row(vector<vector<ll>> &a, bool inv) {
    const int n = (int)a.size();
    if(n < 2) return;   // a single row: the column transform is the identity
    const int m = (int)a[0].size();
    bit_reverse(a);     // swaps whole rows, O(1) per swap
    const vector<ll> roots = twiddles(n, inv);
    for(int len = 2; len <= n; len <<= 1){
      const int step = n / len, half = len / 2;
      for(int s = 0; s < n; s += len){
        for(int k = 0; k < half; k++){
          const ll w = roots[step * k];
          vector<ll> &lo = a[s + k], &hi = a[s + k + half];
          for(int h = 0; h < m; h++){
            const ll u = lo[h], v = hi[h] * w % mod;
            lo[h] = (u + v) % mod;
            hi[h] = (u + mod - v) % mod;
          }
        }
      }
    }
    if(inv){
      // Normalise by n, the length of THIS (column) transform. The row width m
      // is already handled by the per-row ntt(); dividing by m here would make
      // mult2d wrong by a factor n/m whenever the padded sizes differ.
      const ll ninv = multInv(n);
      for(auto& row : a) scale(row, ninv);
    }
  }

  // Full 2D transform: every row, then every column.
  void ntt2d(vector<vector<ll>> &a, bool inv) {
    for(auto& row : a) ntt(row, inv);
    ntt_row(a, inv);
  }

  // Linear 2D convolution modulo mod:
  //   result[i][j] = sum_{x,y} a[x][y] * b[i-x][j-y]  (mod mod).
  // The result is zero-padded to n x m, where n and m are the smallest powers of
  // two (>= 2) with n >= rows(a) + rows(b) and m >= cols(a[0]) + cols(b[0]).
  // Returns an empty matrix if either input has no rows.
  vector<vector<ll>> mult2d(vector<vector<ll>> a, vector<vector<ll>> b) {
    if(a.empty() || b.empty()) return {};
    int n = 2; while(n < (int)(a.size() + b.size())) n <<= 1;
    int m = 2; while(m < (int)(a[0].size() + b[0].size())) m <<= 1;
    a.resize(n); b.resize(n);
    for(auto& row : a) row.resize(m);
    for(auto& row : b) row.resize(m);
    ntt2d(a, false); ntt2d(b, false);
    for(int i = 0; i < n; i++) for(int j = 0; j < m; j++) a[i][j] = a[i][j] * b[i][j] % mod;
    ntt2d(a, true);
    return a;
  }

  // Linear 1D convolution modulo mod, zero-padded to the smallest power of two
  // n >= 2 with n >= a.size() + b.size(). The inputs are not modified.
  vector<ll> mult(const vector<ll> &a, const vector<ll> &b){
    int n = 2; while(n < (int)(a.size() + b.size())) n <<= 1;
    vector<ll> fa(a), fb(b);
    fa.resize(n); fb.resize(n);
    ntt(fa, false); ntt(fb, false);
    for(int i = 0; i < n; i++) fa[i] = fa[i] * fb[i] % mod;
    ntt(fa, true);
    return fa;
  }

 private:
  // In-place bit-reversal permutation of `a` (a.size() a power of two).
  template<class T>
  static void bit_reverse(vector<T>& a){
    const int n = (int)a.size();
    for(int i = 1, j = 0; i < n; i++){
      int bit = n >> 1;
      for(; j >= bit; bit >>= 1) j -= bit;
      j += bit;
      if(i < j) swap(a[i], a[j]);
    }
  }

  // roots[i] = g^i for 0 <= i < n/2, where g is a primitive n-th root of unity
  // modulo mod (its inverse when inv == true).
  vector<ll> twiddles(int n, bool inv){
    vector<ll> roots(n / 2);
    ll ang = pw(root, (mod - 1) / n);
    if(inv) ang = multInv(ang);
    for(int i = 0; i < n / 2; i++) roots[i] = i ? roots[i - 1] * ang % mod : 1;
    return roots;
  }

  // Multiplies every element of v by f modulo mod.
  static void scale(vector<ll>& v, ll f){
    for(ll& x : v) x = x * f % mod;
  }
};

// The two NTT primes; merge_by_crt reconstructs values modulo their product
// mod1 * mod2 (about 4.7e17).
constexpr ll mod1 = 998244353;
constexpr ll mod2 = (7LL << 26) + 1;   // 469762049
NTT<mod1> n1;
NTT<mod2> n2;
// a: the n x m input grid; b: the r x c grid, stored rotated by 180 degrees.
vector<vector<ll>> a, b;
int n, m;
int r, c;
// Moduli used to compare cov^2 against var(a) * var(b) by residues
// (presumably because the exact products may overflow __int128).
const ll tmods[] = {1000000007, 1000000009, 998244353, 1000000033, 1000000021};
using L = __int128;

void merge_by_crt(vector<vector<ll>> &a, vector<vector<ll>> &b) {
  ll inv1 = n2.multInv(mod1), inv2 = n1.multInv(mod2), pmod = mod1 * mod2;
  for(int i = 0; i < szz(a); i++) {
    for(int j = 0; j < szz(b); j++) {
      a[i][j] = (
        __int128_t(a[i][j]) * mod2 * inv2 + __int128_t(b[i][j]) * mod1 * inv1
      ) % pmod;
    }
  }
}

int main(){
  fio;
  cin >> n >> m;
  a = vector<vector<ll>>(n, vector<ll>(m));
  for(int i = 0; i < n; i++) for(int j = 0; j < m; j++) cin >> a[i][j];
  cin >> r >> c;
  b = vector<vector<ll>>(r, vector<ll>(c));
  for(int i = 0; i < r; i++) for(int j = 0; j < c; j++) cin >> b[r-1-i][c-1-j];
  auto ab = n1.mult2d(a, b);
  auto ab2 = n2.mult2d(a, b);
  merge_by_crt(ab, ab2);

  #ifdef TAMREF
  cerr << ab << '\n';
  #endif
  __int128_t asum = 0, a2sum = 0, bsum = 0, b2sum = 0;

  for(int i = 0; i < r; i++) {
    for(int j = 0; j < c; j++) {
      bsum += b[i][j];
      b2sum += b[i][j] * b[i][j];
    }
  }

  ll ans = 0;


  auto compare = [&](int i, int j) {
    bool isConstA = !(a2sum * r * c - asum * asum), 
         isConstB = !(b2sum * r * c - bsum * bsum);

    if(isConstA && !isConstB) return;
    if(isConstA && isConstB) {
      ++ans;
      return;
    }


    L bvar = (b2sum * r * c - bsum * bsum);
    L cov = (L(ab[i + r - 1][j + c - 1]) * r * c - asum * bsum);
    L avar =(a2sum * r * c - asum * asum);
    if(!avar && bvar) return;
    if(!avar && !bvar) {
      ++ans;
      return;
    }

    bool flag = true;
    for(const ll mod : tmods) {
        ll _c = (mod + cov % mod) % mod;
        ll _a = avar % mod, _b = bvar % mod;
        ll diff = (mod + (_c * _c - _a * _b) % mod) % mod;
        if(diff) {
            flag = false;
            break;
        }
    }
    if(flag) ++ans;
    debug("ans = %lld\n", ans);
  };

  for(int k = 0; k + r - 1 < n; k++) {
    asum = a2sum = 0;
    for(int i = 0; i < r; i++) {
      for(int j = 0; j < c; j++) {
        asum += a[k+i][j];
        a2sum += a[k+i][j] * a[k+i][j];
      }
    }

    compare(k, 0);
    for(int j = 1; j + c - 1 < m; j++) {
      for(int i = 0; i < r; i++) {
        asum += a[k+i][j+c-1] - a[k+i][j-1];
        a2sum += a[k+i][j+c-1] * a[k+i][j+c-1] - a[k+i][j-1] * a[k+i][j-1];
      }
      compare(k, j);
    }
  }

  cout << ans << '\n';
}

Compilation message

G.cpp: In lambda function:
G.cpp:19:20: warning: statement has no effect [-Wunused-value]
   19 | #define debug(...) 42
      |                    ^~
G.cpp:247:5: note: in expansion of macro 'debug'
  247 |     debug("ans = %lld\n", ans);
      |     ^~~~~
# 결과 실행 시간 메모리 Grader output
1 Correct 6 ms 852 KB Output is correct
2 Correct 6 ms 852 KB Output is correct
3 Correct 7 ms 852 KB Output is correct
4 Correct 8 ms 976 KB Output is correct
5 Correct 7 ms 980 KB Output is correct
6 Correct 7 ms 980 KB Output is correct
7 Correct 6 ms 980 KB Output is correct
8 Correct 6 ms 980 KB Output is correct
9 Correct 7 ms 1108 KB Output is correct
10 Correct 6 ms 976 KB Output is correct
11 Correct 7 ms 852 KB Output is correct
12 Correct 6 ms 940 KB Output is correct
13 Correct 7 ms 960 KB Output is correct
14 Correct 9 ms 980 KB Output is correct
15 Correct 8 ms 980 KB Output is correct
16 Correct 6 ms 980 KB Output is correct
17 Correct 6 ms 976 KB Output is correct
18 Correct 7 ms 976 KB Output is correct
19 Correct 7 ms 980 KB Output is correct
20 Correct 7 ms 956 KB Output is correct
21 Correct 8 ms 980 KB Output is correct
22 Correct 9 ms 980 KB Output is correct
23 Correct 7 ms 928 KB Output is correct
24 Correct 7 ms 980 KB Output is correct
25 Correct 6 ms 980 KB Output is correct
26 Correct 7 ms 980 KB Output is correct
27 Correct 6 ms 848 KB Output is correct
28 Correct 7 ms 968 KB Output is correct
29 Correct 6 ms 972 KB Output is correct
30 Correct 8 ms 852 KB Output is correct
31 Correct 7 ms 980 KB Output is correct
32 Correct 6 ms 980 KB Output is correct
# 결과 실행 시간 메모리 Grader output
1 Correct 1676 ms 107052 KB Output is correct
2 Correct 1738 ms 106984 KB Output is correct
3 Correct 1801 ms 107088 KB Output is correct
4 Correct 1792 ms 107640 KB Output is correct
5 Correct 1842 ms 107640 KB Output is correct
6 Correct 1798 ms 107788 KB Output is correct
7 Correct 1749 ms 107536 KB Output is correct
8 Correct 1736 ms 107488 KB Output is correct
9 Correct 1811 ms 107572 KB Output is correct
10 Correct 1834 ms 107544 KB Output is correct
11 Correct 1845 ms 107716 KB Output is correct
12 Correct 1898 ms 107612 KB Output is correct
13 Correct 1728 ms 107500 KB Output is correct
14 Correct 1733 ms 107624 KB Output is correct
15 Correct 1796 ms 113004 KB Output is correct
16 Correct 2222 ms 116136 KB Output is correct
17 Correct 2198 ms 116144 KB Output is correct
18 Correct 2122 ms 113964 KB Output is correct
19 Correct 2327 ms 115192 KB Output is correct
20 Correct 2157 ms 116228 KB Output is correct
21 Correct 2284 ms 115508 KB Output is correct
22 Correct 1688 ms 126552 KB Output is correct
23 Correct 2053 ms 114896 KB Output is correct
24 Correct 2024 ms 114724 KB Output is correct
25 Correct 2088 ms 115716 KB Output is correct
26 Correct 2031 ms 112040 KB Output is correct
27 Correct 2058 ms 115924 KB Output is correct
28 Correct 1802 ms 112596 KB Output is correct
29 Correct 1695 ms 112444 KB Output is correct
30 Correct 489 ms 48848 KB Output is correct
31 Correct 451 ms 47324 KB Output is correct
32 Correct 2025 ms 114268 KB Output is correct