#include<bits/stdc++.h>
#include <immintrin.h>
#include <smmintrin.h>
using namespace std;
#pragma GCC target("avx2")
#pragma GCC target("fma")
//To whoever reads this code:
//This code uses AVX2 operations for a fast FFT.
//Whether it runs depends on the judging environment (the CPU must support AVX2/FMA), so use it with care.
//Simple debug helper: "print" writes the four lanes of an __m256d.
// Debug helper: writes the four double lanes of `a` separated by single spaces,
// then ends the line with std::endl (which also flushes the stream).
void print(__m256d a)
{
    for (int lane = 0; lane < 3; ++lane)
        cout << a[lane] << " ";
    cout << a[3] << endl;
}
/*
Data flow of function "mult": a and b each hold two complex numbers, packed as (re, im, re, im).
The return value holds the two element-wise complex products a1*b1 and a2*b2.
Required operations: 6
+-------+----movedup----+-------+ +-------+-------+-------+-------+ +-------+----shuffle----+-------+
| a1R | a1R | a2R | a2R | <- | a1R | a1I | a2R | a2I | -> | a1I | a1I | a2I | a2I |
+-------+-------+-------+-------+ +-------+-------+-------+-------+ +-------+-------+-------+-------+
| |
v v
+-------+------mul------+-------+ +-------+-------+-------+-------+ +-------+------mul------+-------+
|a1R*b1R|a1R*b1I|a2R*b2R|a2R*b2I| <- | b1R | b1I | b2R | b2I | -> |a1I*b1R|a1I*b1I|a2I*b2R|a2I*b2I|
+-------+-------+-------+-------+ +-------+-------+-------+-------+ +-------+-------+-------+-------+
|
- + - + |
|
+-------+----shuffle----+-------+ |
|a1I*b1I|a1I*b1R|a2I*b2I|a2I*b2R| <-------------------------------------------------------+
+-------+-------+-------+-------+
| addsub
v
+---------------+---------------+---------------+---------------+
|a1R*b1R-a1I*b1I|a1R*b1I+a1I*b1R|a2R*b2R-a2I*b2I|a2R*b2I+a2I*b2R|
+---------------+---------------+---------------+---------------+
*/
// Multiplies two pairs of packed complex numbers lane-wise:
// a = (a1R, a1I, a2R, a2I), b = (b1R, b1I, b2R, b2I)  ->  (a1*b1, a2*b2).
__m256d mult(__m256d a, __m256d b)
{
    // Per 128-bit lane: re = [aR, aR], im = [aI, aI].
    const __m256d re = _mm256_movedup_pd(a);
    const __m256d im = _mm256_permute_pd(a, 0xF);
    // im*b = [aI*bR, aI*bI]; swapping the two halves of each lane gives [aI*bI, aI*bR].
    const __m256d cross = _mm256_permute_pd(_mm256_mul_pd(im, b), 0x5);
    // addsub subtracts in even slots and adds in odd slots:
    // [aR*bR - aI*bI, aR*bI + aI*bR] for each complex pair.
    return _mm256_addsub_pd(_mm256_mul_pd(re, b), cross);
}
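/*
Plain iterative radix-2 FFT.
The k-th power of the root of unity is obtained by repeatedly multiplying w by wlen instead of reading a precomputed table.
This reduces memory accesses and lets a single AVX multiplication produce both the next twiddle and the twiddled data element.
It also avoids calling sin/cos (library functions, not system calls) for every k.
*/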
/*
 * In-place iterative radix-2 FFT over n complex values.
 *   n      : number of elements; must be a power of two (the bit-reversal
 *            step uses __builtin_ctz(n) as log2(n)).
 *   a      : n complex numbers, each stored as an __m128d {re, im}.
 *   invert : false -> transform with twiddles e^{+2*pi*i/len};
 *            true  -> conjugate twiddles, then every element is scaled by 1/n.
 * Twiddles w^j are produced by repeated multiplication (w *= wlen), so their
 * round-off grows with len; this trades some accuracy for fewer sin/cos calls.
 */
void fft(int n, __m128d a[], bool invert)
{
    // Bit-reversal permutation. j tracks the bit reversal of i: the bits that
    // flip when incrementing i (i ^ (i+1)) are mirrored into the low log2(n)
    // bits by shifting them to the top of the word and back down by 32-log2(n).
    const unsigned sl = 32 - __builtin_ctz(n);
    for (unsigned i = 0, j = 0; i + 1 < static_cast<unsigned>(n); ++i)
    {
        if (i < j) swap(a[i], a[j]);
        unsigned t = i ^ (i + 1);
        t <<= __builtin_clz(t);
        t >>= sl;
        j ^= t;
    }

    // Full-precision pi; the previous literal 3.14159265358979 dropped ~3e-15.
    const double pi = acos(-1.0);
    for (int len = 2; len <= n; len <<= 1)
    {
        const int half = len >> 1;
        const double ang = 2 * pi / len * (invert ? -1 : 1);
        // Low lane: the per-step twiddle e^{i*ang}. The high lane is overwritten
        // with a data element on every inner iteration, so it starts as zero.
        __m256d wlen = _mm256_setr_pd(cos(ang), sin(ang), 0.0, 0.0);
        for (int i = 0; i < n; i += len)
        {
            // w = 1 + 0i in the low lane (fully initialized, no stray lanes).
            __m256d w = _mm256_setr_pd(1.0, 0.0, 0.0, 0.0);
            for (int j = 0; j < half; ++j)
            {
                // Broadcast the current twiddle w^j into both 128-bit lanes...
                w = _mm256_permute2f128_pd(w, w, 0);
                // ...and pair it with {wlen, v}: one complex multiply yields
                // w^(j+1) in the low lane and w^j * v in the high lane.
                wlen = _mm256_insertf128_pd(wlen, a[i + j + half], 1);
                w = mult(w, wlen);
                const __m128d vw = _mm256_extractf128_pd(w, 1);
                const __m128d u = a[i + j];
                a[i + j] = _mm_add_pd(u, vw);
                a[i + j + half] = _mm_sub_pd(u, vw);
            }
        }
    }

    if (invert)
    {
        const __m128d inv = _mm_set1_pd(1.0 / n);
        for (int i = 0; i < n; ++i)
            a[i] = _mm_mul_pd(a[i], inv);
    }
}
/*
 * Linear convolution of two integer sequences via floating-point FFT.
 *   v, w : input sequences (not modified; taken by non-const reference only
 *          to keep the existing signature).
 * Returns a vector of length n, the smallest power of two >= 2 with
 * n >= v.size() + w.size() - 1. Entries beyond the true product length are
 * (after rounding) zero. Each entry is rounded to the nearest integer, so the
 * result is exact only while the FFT round-off error stays below 0.5; callers
 * must keep input magnitudes small enough for that.
 */
vector<int64_t> multiply(vector<int64_t>& v, vector<int64_t>& w)
{
    // Transform size: a power of two, at least 2, so the pairwise product loop
    // below (which processes indices i and i+1 together) always sees an even n.
    // Written as n + 1 < total to avoid size_t underflow when both inputs are empty.
    const size_t total = v.size() + w.size();
    int n = 2;
    while (static_cast<size_t>(n) + 1 < total) n <<= 1;

    // Complex buffers {re, im}, zero-filled; only the real parts carry input.
    // std::vector releases them on every exit path.
    vector<__m128d> fv(n, _mm_setzero_pd());
    vector<__m128d> fw(n, _mm_setzero_pd());
    for (size_t i = 0; i < v.size(); ++i)
        fv[i][0] = v[i];
    for (size_t i = 0; i < w.size(); ++i)
        fw[i][0] = w[i];

    fft(n, fv.data(), false);
    fft(n, fw.data(), false);

    // Pointwise complex product, two frequency bins per AVX multiply.
    for (int i = 0; i < n; i += 2)
    {
        const __m256d x = _mm256_insertf128_pd(_mm256_castpd128_pd256(fv[i]), fv[i + 1], 1);
        const __m256d y = _mm256_insertf128_pd(_mm256_castpd128_pd256(fw[i]), fw[i + 1], 1);
        const __m256d p = mult(x, y);
        fv[i] = _mm256_castpd256_pd128(p);
        fv[i + 1] = _mm256_extractf128_pd(p, 1);
    }

    fft(n, fv.data(), true);

    vector<int64_t> ret(n);
    for (int i = 0; i < n; ++i)
        ret[i] = static_cast<int64_t>(round(fv[i][0]));
    return ret;
}
// Input buffer for _read(); each fread() refill reads at most this many bytes.
static char _buffer[1 << 19]; // size can be tuned
// Index of the next unread byte in _buffer.
static int _currentChar = 0;
// Number of valid bytes delivered by the last refill (0 before the first one and at end of input).
static int _charsNumber = 0;
/*
 * Returns the next byte of stdin (a char value converted to int), or -1 once
 * fread() delivers no more data. Refills _buffer when it is empty or fully consumed.
 * NOTE(review): where char is signed, a 0xFF input byte also comes back as -1.
 */
static inline int _read() {
    // Defensive only: the refill count never exceeds sizeof(_buffer), so it cannot go negative.
    if (_charsNumber < 0) exit(1);
    const bool needRefill = (_charsNumber == 0) || (_currentChar == _charsNumber);
    if (needRefill) {
        _currentChar = 0;
        _charsNumber = (int)fread(_buffer, sizeof(_buffer[0]), sizeof(_buffer), stdin);
    }
    return (_charsNumber > 0) ? _buffer[_currentChar++] : -1;
}
/*
 * Reads one integer token from stdin via _read().
 * Bytes with value <= 32 (spaces, newlines, control characters) separate tokens.
 * An optional leading '-' negates the value. There is no validation: every
 * other non-separator byte is treated as a decimal digit.
 * Returns 0 if end of input (_read() == -1) is reached before a token starts.
 * The previous version spun forever in that case, because -1 <= 32.
 */
static inline int _readInt() {
    int c = _read();
    // Skip separators, but stop at end of input instead of looping on -1.
    while (c != -1 && c <= 32) c = _read();
    if (c == -1) return 0;

    bool negative = false;
    if (c == '-') {
        negative = true;
        c = _read();
    }

    int x = 0;
    // -1 (end of input) is <= 32, so this loop also stops at EOF.
    while (c > 32) {
        x = x * 10 + (c - '0');
        c = _read();
    }
    return negative ? -x : x;
}
/*
 * Reads two polynomials and prints the XOR of all coefficients of their product.
 * Input: the degrees of f and g, then f's coefficients, then g's coefficients
 * (a degree-d polynomial has d+1 coefficients).
 * To keep the floating-point FFT products exact, every coefficient is split into
 * three 7-bit limbs. Limb i of f is convolved with limb j of g, and the partial
 * products are recombined with shifts of 7*(i+j) bits.
 * NOTE(review): this is exact only for coefficients in [0, 2^21); confirm against
 * the problem's bounds.
 */
int main()
{
    const int kLimbs = 3;
    const int kLimbBits = 7;
    const int64_t kLimbMask = (int64_t(1) << kLimbBits) - 1;

    int N = _readInt()+1; int M = _readInt()+1;
    vector<int64_t> _a(N), _b(M);
    for(int i=0; i<N; ++i) _a[i] = _readInt();
    for(int i=0; i<M; ++i) _b[i] = _readInt();

    // a[i][j] holds bits [7i, 7i+7) of coefficient j of f; likewise b for g.
    vector<int64_t> a[kLimbs], b[kLimbs];
    for(int i=0; i<kLimbs; ++i)
    {
        a[i].resize(N);
        b[i].resize(M);
        for(int j=0; j<N; ++j)
        {
            a[i][j] = _a[j] & kLimbMask;
            _a[j] >>= kLimbBits;
        }
        for(int j=0; j<M; ++j)
        {
            b[i][j] = _b[j] & kLimbMask;
            _b[j] >>= kLimbBits;
        }
    }

    // Every cross term is needed: res[i][j] = (limb i of f) * (limb j of g).
    // Using b[i] here would multiply matching limbs only and give wrong coefficients.
    vector<int64_t> res[kLimbs][kLimbs];
    for(int i=0; i<kLimbs; ++i)
        for(int j=0; j<kLimbs; ++j)
            res[i][j] = multiply(a[i], b[j]);

    // All res[i][j] have the same length (it depends only on N + M).
    // Padding entries round to zero and do not affect the XOR.
    int64_t ans = 0;
    for(size_t i=0; i<res[0][0].size(); ++i)
    {
        int64_t t = 0;
        for(int j=0; j<kLimbs; ++j)
            for(int k=0; k<kLimbs; ++k)
                t += res[j][k][i] << (kLimbBits*(j+k));
        ans ^= t;
    }
    std::cout << ans << '\n';
    return 0;
}
/*
Compilation message (from an earlier submission compiled as tt.cpp; its line numbers do not match this file):
tt.cpp: In function 'void fft(int, __m128d*, bool)':
tt.cpp:69:25: warning: comparison between signed and unsigned integer expressions [-Wsign-compare]
for(unsigned i=0,j=0; i<n-1;++i)
^
tt.cpp: In function 'std::vector<long int> multiply(std::vector<long int>&, std::vector<long int>&)':
tt.cpp:107:24: warning: comparison between signed and unsigned integer expressions [-Wsign-compare]
int n = 2; while(n < v.size()+w.size()-1) n<<=1;
^
tt.cpp:113:19: warning: comparison between signed and unsigned integer expressions [-Wsign-compare]
for(int i=0; i<v.size(); ++i)
^
tt.cpp:115:19: warning: comparison between signed and unsigned integer expressions [-Wsign-compare]
for(int i=0; i<w.size(); ++i)
^
tt.cpp: In function 'int main()':
tt.cpp:205:19: warning: comparison between signed and unsigned integer expressions [-Wsign-compare]
for(int i=0; i<res[0][0].size(); ++i)
^
*/
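/*
Grader results (Korean headers: 결과 = Result, 실행 시간 = Execution time, 메모리 = Memory):

Table 1
| # | Result (결과) | Execution time (실행 시간) | Memory (메모리) | Grader output        |
| 1 | Incorrect     | 0 ms                       | 2576 KB         | Output isn't correct |
| 2 | Halted        | 0 ms                       | 0 KB            | -                    |

Table 2
| # | Result (결과) | Execution time (실행 시간) | Memory (메모리) | Grader output        |
| 1 | Incorrect     | 53 ms                      | 6756 KB         | Output isn't correct |
| 2 | Halted        | 0 ms                       | 0 KB            | -                    |

Table 3
| # | Result (결과) | Execution time (실행 시간) | Memory (메모리) | Grader output        |
| 1 | Incorrect     | 1053 ms                    | 65648 KB        | Output isn't correct |
| 2 | Halted        | 0 ms                       | 0 KB            | -                    |
*/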