Submission #32608

# | Time | Username | Problem | Language | Result | Execution time | Memory
32608 | | ho94949 | 씽크스몰 (kriii3_TT) | C++14 | 30 / 30 | 6979 ms | 278140 KiB
#include <bits/stdc++.h>
#include <immintrin.h>
#include <smmintrin.h>

#pragma GCC target("avx2")
#pragma GCC target("fma")

// To whoever might read this code:
// This code uses AVX2 operations for fast FFT calculation.
// Whether it runs depends on the judging environment (the CPU must support
// AVX2), so it should be used carefully.

namespace {

// More digits than a double can hold, so the literal rounds to the nearest
// representable value of pi.
constexpr double kPi = 3.14159265358979323846;

// Owning, zero-initialised array of packed complex numbers (one __m128d holds
// {real, imag}). Exists only so the FFT scratch buffers are released on every
// exit path, including exceptions.
struct ComplexBuffer {
    __m128d* data;

    explicit ComplexBuffer(std::size_t count) : data(new __m128d[count]()) {}
    ~ComplexBuffer() { delete[] data; }

    ComplexBuffer(const ComplexBuffer&) = delete;
    ComplexBuffer& operator=(const ComplexBuffer&) = delete;

    __m128d& operator[](std::size_t i) { return data[i]; }
};

// Packs two complex numbers into one 256-bit vector: {lo.re, lo.im, hi.re, hi.im}.
inline __m256d pack(__m128d lo, __m128d hi)
{
    return _mm256_insertf128_pd(_mm256_castpd128_pd256(lo), hi, 1);
}

}  // namespace

// Simple debug helper: prints the four doubles of `a` separated by spaces.
void print(__m256d a)
{
    std::cout << a[0] << " " << a[1] << " " << a[2] << " " << a[3] << std::endl;
}

/*
Flow of function "mult"

a and b each hold two complex numbers, laid out as {1R, 1I, 2R, 2I}.
The return value holds the two products a1*b1 and a2*b2.
Required operations: 6

  c  = movedup(a)        -> | a1R     | a1R     | a2R     | a2R     |
  d  = shuffle(a, a, 15) -> | a1I     | a1I     | a2I     | a2I     |
  cb = c * b             -> | a1R*b1R | a1R*b1I | a2R*b2R | a2R*b2I |
  db = d * b             -> | a1I*b1R | a1I*b1I | a2I*b2R | a2I*b2I |
  e  = shuffle(db,db,5)  -> | a1I*b1I | a1I*b1R | a2I*b2I | a2I*b2R |
  r  = addsub(cb, e)     -> | a1R*b1R-a1I*b1I | a1R*b1I+a1I*b1R |
                            | a2R*b2R-a2I*b2I | a2R*b2I+a2I*b2R |
*/
__m256d mult(__m256d a, __m256d b)
{
    const __m256d re = _mm256_movedup_pd(a);         // real parts duplicated
    const __m256d im = _mm256_shuffle_pd(a, a, 15);  // imaginary parts duplicated
    const __m256d re_b = _mm256_mul_pd(re, b);
    const __m256d im_b = _mm256_mul_pd(im, b);
    const __m256d im_b_swapped = _mm256_shuffle_pd(im_b, im_b, 5);
    // addsub subtracts in even slots (real parts) and adds in odd slots
    // (imaginary parts), which is exactly the complex product.
    return _mm256_addsub_pd(re_b, im_b_swapped);
}

/*
In-place iterative radix-2 FFT.

  n      : number of complex elements in `a`; must be a power of two.
  a      : array of n complex numbers, each stored as {real, imag}.
  invert : false for the forward transform, true for the inverse transform
           (which also scales every element by 1/n).

The k-th power of the root of unity is obtained by repeatedly multiplying
w by wlen instead of reading a precalculated table. The twiddle update
(w * wlen) and the butterfly product (w * a[...]) are done by a single call to
mult(), which reduces memory accesses and the number of sin/cos calls.
*/
void fft(int n, __m128d a[], bool invert)
{
    // Nothing to transform; this also keeps ctz(0) and shift-by-32 out of reach.
    if (n < 2) return;

    // Bit-reversal permutation. j tracks the bit-reversed value of i; going
    // from i to i+1 flips a run of low bits, whose mirror image is XOR-ed
    // into j.
    const unsigned count = static_cast<unsigned>(n);
    const unsigned shift_down = 32 - __builtin_ctz(count);
    for (unsigned i = 0, j = 0; i + 1 < count; ++i) {
        if (i < j) std::swap(a[i], a[j]);
        unsigned flipped = i ^ (i + 1);        // mask of the bits that change
        flipped <<= __builtin_clz(flipped);    // move the run to the top bits
        flipped >>= shift_down;                // mirror it into the low log2(n) bits
        j ^= flipped;
    }

    for (int len = 2; len <= n; len <<= 1) {
        const int half = len / 2;
        const double ang = 2 * kPi / len * (invert ? -1 : 1);
        // Primitive len-th root of unity as {cos, sin}.
        const __m128d wlen = _mm_set_pd(std::sin(ang), std::cos(ang));

        for (int i = 0; i < n; i += len) {
            // Low lane carries the running twiddle, starting at 1 + 0i.
            __m256d w = _mm256_set_pd(0.0, 1.0, 0.0, 1.0);
            for (int j = 0; j < half; ++j) {
                // Broadcast the current twiddle into both lanes.
                w = _mm256_permute2f128_pd(w, w, 0);
                // One mult() yields {w * wlen, w * a[i+j+half]}: the next
                // twiddle in the low lane, the butterfly term in the high lane.
                w = mult(w, pack(wlen, a[i + j + half]));
                const __m128d vw = _mm256_extractf128_pd(w, 1);
                const __m128d u = a[i + j];
                a[i + j] = _mm_add_pd(u, vw);
                a[i + j + half] = _mm_sub_pd(u, vw);
            }
        }
    }

    if (invert) {
        const __m128d inv = _mm_set1_pd(1.0 / n);
        for (int i = 0; i < n; ++i) a[i] = _mm_mul_pd(a[i], inv);
    }
}

/*
Multiplies two polynomials given by their integer coefficients (lowest degree
first) using the FFT above.

Returns a vector of length n, where n is the smallest power of two that is
>= 2 and >= v.size() + w.size() - 1; entries past the true degree are zero.
Each coefficient is rounded from a double, so the product coefficients must be
small enough to be represented exactly.
*/
std::vector<std::int64_t> multiply(const std::vector<std::int64_t>& v,
                                   const std::vector<std::int64_t>& w)
{
    // v+w-1 elements are enough; the size must be even (power of two).
    // Written as n + 1 < need so that two empty inputs do not wrap around.
    const std::size_t need = v.size() + w.size();
    int n = 2;
    while (static_cast<std::size_t>(n) + 1 < need) n <<= 1;

    ComplexBuffer fv(n);
    ComplexBuffer fw(n);
    for (std::size_t i = 0; i < v.size(); ++i)
        fv[i] = _mm_set_pd(0.0, static_cast<double>(v[i]));
    for (std::size_t i = 0; i < w.size(); ++i)
        fw[i] = _mm_set_pd(0.0, static_cast<double>(w[i]));

    fft(n, fv.data, false);
    fft(n, fw.data, false);

    // Pointwise product, two complex numbers per mult() call.
    for (int i = 0; i < n; i += 2) {
        const __m256d prod = mult(pack(fv[i], fv[i + 1]), pack(fw[i], fw[i + 1]));
        fv[i] = _mm256_castpd256_pd128(prod);
        fv[i + 1] = _mm256_extractf128_pd(prod, 1);
    }

    fft(n, fv.data, true);

    std::vector<std::int64_t> ret(n);
    for (int i = 0; i < n; ++i)
        ret[i] = static_cast<std::int64_t>(std::round(_mm_cvtsd_f64(fv[i])));
    return ret;
}

// Buffered stdin reader state.
static char _buffer[1 << 19];  // Can be changed
static int _currentChar = 0;
static int _charsNumber = 0;

// Returns the next byte of stdin as a value in [0, 255], or -1 at end of input.
static inline int _read()
{
    if (_currentChar == _charsNumber) {
        // Buffer exhausted (or never filled): refill it.
        _charsNumber = static_cast<int>(
            std::fread(_buffer, sizeof(_buffer[0]), sizeof(_buffer), stdin));
        _currentChar = 0;
    }
    if (_charsNumber <= 0) return -1;
    // Go through unsigned char so a 0xFF byte cannot be mistaken for EOF.
    return static_cast<unsigned char>(_buffer[_currentChar++]);
}

// Reads one optionally negative decimal integer, skipping leading whitespace.
// Returns 0 if the input ends before any digit is found.
static inline int _readInt()
{
    int c = _read();
    // Skip whitespace/control bytes, but stop at EOF instead of spinning.
    while (c >= 0 && c <= 32) c = _read();

    int sign = 1;
    if (c == '-') {
        sign = -1;
        c = _read();
    }

    int value = 0;
    while (c > 32) {
        value = value * 10 + (c - '0');
        c = _read();
    }
    return sign < 0 ? -value : value;
}

// Reads the degrees and coefficients of two polynomials, multiplies them and
// prints the XOR of all coefficients of the product.
//
// Each coefficient is split into three 7-bit digits so that every partial
// product stays small enough for the double-precision FFT; the partial
// products are recombined with weights 2^(7*(j+k)).
int main()
{
    constexpr int kParts = 3;
    constexpr int kBits = 7;
    constexpr std::int64_t kMask = (std::int64_t{1} << kBits) - 1;

    const int N = _readInt() + 1;
    const int M = _readInt() + 1;
    if (N < 1 || M < 1) return 1;  // malformed degree

    std::vector<std::int64_t> coef_a(N), coef_b(M);
    for (auto& x : coef_a) x = _readInt();
    for (auto& x : coef_b) x = _readInt();

    std::vector<std::int64_t> a[kParts], b[kParts];
    for (int part = 0; part < kParts; ++part) {
        a[part].resize(N);
        b[part].resize(M);
        for (int j = 0; j < N; ++j) {
            a[part][j] = coef_a[j] & kMask;
            coef_a[j] >>= kBits;
        }
        for (int j = 0; j < M; ++j) {
            b[part][j] = coef_b[j] & kMask;
            coef_b[j] >>= kBits;
        }
    }

    // Accumulate the weighted partial products as they are produced, so only
    // one partial product is alive at a time.
    std::vector<std::int64_t> total;
    for (int j = 0; j < kParts; ++j) {
        for (int k = 0; k < kParts; ++k) {
            const std::vector<std::int64_t> partial = multiply(a[j], b[k]);
            if (total.empty()) total.assign(partial.size(), 0);
            const std::int64_t weight = std::int64_t{1} << (kBits * (j + k));
            for (std::size_t i = 0; i < partial.size(); ++i)
                total[i] += partial[i] * weight;
        }
    }

    std::int64_t ans = 0;
    for (const std::int64_t t : total) ans ^= t;

    std::cout << ans << '\n';
    return 0;
}

Compilation message (stderr)

tt.cpp: In function 'void fft(int, __m128d*, bool)':
tt.cpp:69:25: warning: comparison between signed and unsigned integer expressions [-Wsign-compare]
  for(unsigned i=0,j=0; i<n-1;++i)
                         ^
tt.cpp: In function 'std::vector<long int> multiply(std::vector<long int>&, std::vector<long int>&)':
tt.cpp:107:24: warning: comparison between signed and unsigned integer expressions [-Wsign-compare]
     int n = 2; while(n < v.size()+w.size()-1) n<<=1; 
                        ^
tt.cpp:113:19: warning: comparison between signed and unsigned integer expressions [-Wsign-compare]
     for(int i=0; i<v.size(); ++i)
                   ^
tt.cpp:115:19: warning: comparison between signed and unsigned integer expressions [-Wsign-compare]
     for(int i=0; i<w.size(); ++i)
                   ^
tt.cpp: In function 'int main()':
tt.cpp:205:19: warning: comparison between signed and unsigned integer expressions [-Wsign-compare]
     for(int i=0; i<res[0][0].size(); ++i)
                   ^
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...