답안 #1086431

# 제출 시각 아이디 문제 언어 결과 실행 시간 메모리
1086431 2024-09-10T15:14:42 Z shikgom2 씽크스몰 (kriii3_TT) C++17
10 / 30
499 ms 30908 KB
#include <bits/stdc++.h>
// Expands to the iterator pair [v.begin(), v.end()), e.g. vector<ll> copy(all(v)).
#define all(v) v.begin(), v.end()
using namespace std;

using ll = long long;     // 64-bit signed integer used for every modular computation
using poly = vector<ll>;  // polynomial as a coefficient vector

// Modular exponentiation by repeated squaring: returns a^b mod `mod` for b >= 0
// (for b == 0 the result is 1, not reduced by mod).
// a is not pre-reduced, so a * a and (mod - 1)^2 must fit in ll.
ll pw(ll a, ll b, ll mod){
    ll acc = 1;
    for (ll base = a; b != 0; b >>= 1) {
        if (b & 1) acc = acc * base % mod;
        base = base * base % mod;
    }
    return acc;
}

// Number-theoretic transform over Z/modZ with generator w.
// Assumes w is a primitive root of the prime mod (TODO confirm for each instantiation).
// Every product below is of two values in [0, mod), so mod must stay below ~3.03e9
// for the products to fit in ll.
template<ll mod, ll w>
class NTT{
public:
    // In-place iterative radix-2 (Cooley-Tukey) transform of f.
    //   f   : f.size() must be a power of two dividing mod - 1; entries must lie in [0, mod).
    //   inv : false -> forward transform using the root w^((mod-1)/n);
    //         true  -> inverse transform (inverse root, then every entry scaled by n^{-1}).
    void ntt(poly &f, bool inv = 0){
        const int n = static_cast<int>(f.size());
        // Length 0 or 1: the transform is the identity. Returning here also avoids the
        // division by zero in (mod - 1) / n for n == 0 and the write to root[0] of an
        // empty root table for n == 1.
        if(n <= 1) return;

        // Bit-reversal permutation: j tracks the bit-reversed counterpart of i.
        for(int i=1, j=0; i<n; i++){
            int bit = n >> 1;
            while(j >= bit){
                j -= bit; bit >>= 1;
            }
            j += bit;
            if(i < j) swap(f[i], f[j]);
        }

        // root[i] = ang^i, where ang is a primitive n-th root of unity (inverted for inv).
        const int half_n = n >> 1;
        vector<ll> root(half_n);
        ll ang = pw(w, (mod - 1) / n, mod);
        if(inv) ang = pw(ang, mod - 2, mod);
        root[0] = 1;
        for(int i=1; i<half_n; i++) root[i] = root[i-1] * ang % mod;

        // Butterfly passes: a block of length len uses every (n / len)-th root.
        for(int len=2; len<=n; len<<=1){
            const int half = len >> 1, step = n / len;
            for(int start=0; start<n; start+=len){
                for(int k=0; k<half; k++){
                    const ll u = f[start + k];
                    const ll v = f[start + k + half] * root[step * k] % mod;
                    f[start + k] = (u + v) % mod;
                    f[start + k + half] = (u - v + mod) % mod;  // u, v in [0, mod)
                }
            }
        }

        // Only the inverse transform needs the 1/n scaling.
        if(inv){
            const ll n_inv = pw(n, mod - 2, mod);
            for(int i=0; i<n; i++) f[i] = f[i] * n_inv % mod;
        }
    }

    // Linear convolution of _a and _b modulo mod.
    // The operands are copied (not modified) and their coefficients are first reduced into
    // [0, mod), so negative or oversized coefficients are accepted.
    // The result has length n, the smallest power of two with n >= 2 and
    // n >= _a.size() + _b.size(). Only the first _a.size() + _b.size() - 1 entries can be
    // non-zero. n must divide mod - 1.
    vector<ll> multiply(const poly &_a, const poly &_b){
        vector<ll> a(all(_a)), b(all(_b));
        for(ll &x : a) x = (x % mod + mod) % mod;
        for(ll &x : b) x = (x % mod + mod) % mod;
        size_t n = 2;
        while(n < a.size() + b.size()) n <<= 1;
        a.resize(n); b.resize(n);
        ntt(a); ntt(b);
        for(size_t i=0; i<n; i++) a[i] = a[i] * b[i] % mod;
        ntt(a, true);
        return a;
    }
};

// Extended Euclid, iterative form: returns g and sets x, y with a*x + b*y = g,
// where g = gcd(a, b) for non-negative inputs. Yields the same (x, y) pair as the
// usual recursive formulation.
ll ext_gcd(ll a, ll b, ll &x, ll &y) {
    ll r0 = a, r1 = b;  // remainder sequence
    ll s0 = 1, s1 = 0;  // coefficients of a
    ll t0 = 0, t1 = 1;  // coefficients of b
    while (r1 != 0) {
        const ll q = r0 / r1;
        ll next = r0 - q * r1; r0 = r1; r1 = next;
        next = s0 - q * s1;    s0 = s1; s1 = next;
        next = t0 - q * t1;    t0 = t1; t1 = next;
    }
    x = s0;
    y = t0;
    return r0;
}

// NTT-friendly primes of the form c * 2^k + 1; the code uses 3 as the generator for each:
//   m1 = 17 * 2^27 + 1,  m2 = 37 * 2^26 + 1,  m3 = 119 * 2^23 + 1.
// All three therefore support transform lengths up to 2^23. Their product (~5.66e27)
// bounds what three-residue CRT can reconstruct.
constexpr ll m1 = 2281701377LL;
constexpr ll m2 = 2483027969LL;
constexpr ll m3 = 998244353LL;
NTT<m1, 3> ntt1;
NTT<m2, 3> ntt2;
NTT<m3, 3> ntt3;

// Garner's CRT: a[i] is X mod the i-th prime of (m1, m2, m3) for some non-negative X.
// Returns X mod `mod`.
//   a   : at most 3 residues; only the first a.size() primes are used.
//   mod : modulus of the answer, 1 <= mod <= LLONG_MAX.
// The result is X mod `mod` provided X is smaller than the product of the primes used.
// Products that can exceed 64 bits (k * M, M * p, ans + term) are evaluated in __int128.
// Fixed-size arrays replace per-call vectors, because this runs once per output coefficient.
ll f(const vector<ll> &a, ll mod){
    static const ll primes[3] = {m1, m2, m3};
    const int sz = static_cast<int>(a.size());
    assert(sz <= 3);            // only three primes are available
    ll rmn[3] = {0, 0, 0};      // rmn[t]: value reconstructed so far, mod primes[t]
    ll lm[3] = {1, 1, 1};       // lm[t]: product of the primes already used, mod primes[t]
    ll ans = 0, M = 1;          // M: product of the primes already used, mod `mod`

    for(int i=0; i<sz; i++){
        const ll p = primes[i];
        // Next mixed-radix digit: k = (a[i] - rmn[i]) * lm[i]^{-1} mod p.
        ll k = (a[i] - rmn[i]) % p;
        if(k < 0) k += p;
        ll x, y;
        ext_gcd(lm[i], p, x, y);  // lm[i] * x == 1 (mod p): distinct primes are coprime
        k = k * x % p;            // |k * x| < p^2 fits in ll
        if(k < 0) k += p;
        ans = static_cast<ll>(((__int128)ans + (__int128)k * M % mod) % mod);
        for(int t=i+1; t<sz; t++){
            rmn[t] = (rmn[t] + lm[t] * k) % primes[t];
            lm[t] = lm[t] * p % primes[t];
        }
        M = static_cast<ll>((__int128)M * p % mod);
    }
    return ans;
}

// Multiplies polynomials a and b and returns each product coefficient reduced modulo `mod`.
// `mod` is passed straight to f for the CRT recombination.
// The product is computed modulo each of the three NTT primes and recombined with f.
// It is correct when the coefficients are non-negative and below every prime, and every
// true product coefficient is below m1 * m2 * m3.
// The result is zero-padded to the NTT length (a power of two >= a.size() + b.size()).
poly multiply(poly &a, poly &b, ll mod){
    // NTT::multiply copies its operands internally, so no extra copies are needed here.
    const poly res1 = ntt1.multiply(a, b);
    const poly res2 = ntt2.multiply(a, b);
    const poly res3 = ntt3.multiply(a, b);
    poly ret(res1.size());
    for(size_t i=0; i<res1.size(); i++){
        ret[i] = f({res1[i], res2[i], res3[i]}, mod);
    }
    return ret;
}

int main() {
    ios::sync_with_stdio(0); cin.tie(0); cout.tie(0);
    int n, m;
    cin >> n >> m;
    n++, m++;
    poly a(n), b(m);
    for (int i = 0; i < n; ++i) cin >> a[i];
    for (int i = 0; i < m; ++i) cin >> b[i];
    
    int mod = 1e9+7;
    poly res = multiply(a, b, mod);
    ll ans = 0;
    for (ll v : res) ans ^= v;
    cout << ans << '\n';
}

Compilation message

tt.cpp: In function 'poly multiply(poly&, poly&, ll)':
tt.cpp:103:19: warning: comparison of integer expressions of different signedness: 'int' and 'std::vector<long long int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
  103 |     for(int i=0; i<res1.size(); i++){
      |                  ~^~~~~~~~~~~~
tt.cpp: In instantiation of 'std::vector<long long int> NTT<mod, w>::multiply(poly&, poly&) [with long long int mod = 2281701377; long long int w = 3; poly = std::vector<long long int>]':
tt.cpp:99:37:   required from here
tt.cpp:50:17: warning: comparison of integer expressions of different signedness: 'int' and 'std::vector<long long int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
   50 |         while(n < a.size() + b.size()) n <<= 1;
      |               ~~^~~~~~~~~~~~~~~~~~~~~
tt.cpp: In instantiation of 'std::vector<long long int> NTT<mod, w>::multiply(poly&, poly&) [with long long int mod = 2483027969; long long int w = 3; poly = std::vector<long long int>]':
tt.cpp:100:37:   required from here
tt.cpp:50:17: warning: comparison of integer expressions of different signedness: 'int' and 'std::vector<long long int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
tt.cpp: In instantiation of 'std::vector<long long int> NTT<mod, w>::multiply(poly&, poly&) [with long long int mod = 998244353; long long int w = 3; poly = std::vector<long long int>]':
tt.cpp:101:37:   required from here
tt.cpp:50:17: warning: comparison of integer expressions of different signedness: 'int' and 'std::vector<long long int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
# 결과 실행 시간 메모리 Grader output
1 Correct 0 ms 348 KB Output is correct
2 Correct 1 ms 348 KB Output is correct
3 Correct 0 ms 348 KB Output is correct
4 Correct 1 ms 456 KB Output is correct
5 Correct 2 ms 600 KB Output is correct
6 Correct 2 ms 348 KB Output is correct
7 Correct 2 ms 604 KB Output is correct
8 Correct 2 ms 604 KB Output is correct
# 결과 실행 시간 메모리 Grader output
1 Correct 27 ms 2324 KB Output is correct
2 Correct 120 ms 8604 KB Output is correct
3 Correct 119 ms 9416 KB Output is correct
4 Correct 236 ms 15632 KB Output is correct
5 Correct 245 ms 16488 KB Output is correct
6 Correct 242 ms 16072 KB Output is correct
7 Correct 241 ms 17048 KB Output is correct
8 Incorrect 222 ms 17188 KB Output isn't correct
9 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Incorrect 499 ms 30908 KB Output isn't correct
2 Halted 0 ms 0 KB -