Submission #32607

| # | Submission time | Handle | Problem | Language | Result | Execution time | Memory |
|---|---|---|---|---|---|---|---|
| 32607 | 2017-10-12T00:59:21Z | ho94949 | 씽크스몰 ("Think Small", kriii3_TT) | C++14 | 0 / 30 | 1053 ms | 65648 KB |
#include<bits/stdc++.h>
#include <immintrin.h>
#include <smmintrin.h>
using namespace std;
#pragma GCC target("avx2")
#pragma GCC target("fma")

// Note to readers:
// This code uses AVX2 instructions to speed up the FFT.
// It depends on the judging environment supporting AVX2, so use it with care.


// Debug helper: writes the four lanes of `a` to stdout, separated by single
// spaces, followed by std::endl (newline + flush).
void print(__m256d a)
{
    const char* sep = "";
    for (int lane = 0; lane < 4; ++lane)
    {
        std::cout << sep << a[lane];
        sep = " ";
    }
    std::cout << std::endl;
}

/*

Data flow of function "mult": a and b each hold two complex numbers.
The return value holds the product of each corresponding pair of complex numbers.
Required operations: 6

    +-------+----movedup----+-------+    +-------+-------+-------+-------+    +-------+----shuffle----+-------+
    |  a1R  |  a1R  |  a2R  |  a2R  | <- |  a1R  |  a1I  |  a2R  |  a2I  | -> |  a1I  |  a1I  |  a2I  |  a2I  |
    +-------+-------+-------+-------+    +-------+-------+-------+-------+    +-------+-------+-------+-------+
                    |                                                                         |
                    v                                                                         v
    +-------+------mul------+-------+    +-------+-------+-------+-------+    +-------+------mul------+-------+
    |a1R*b1R|a1R*b1I|a2R*b2R|a2R*b2I| <- |  b1R  |  b1I  |  b2R  |  b2I  | -> |a1I*b1R|a1I*b1I|a2I*b2R|a2I*b2I|
    +-------+-------+-------+-------+    +-------+-------+-------+-------+    +-------+-------+-------+-------+
                                                                                              |
        -       +       -       +                                                             |
                                                                                              |
    +-------+----shuffle----+-------+                                                         |
    |a1I*b1I|a1I*b1R|a2I*b2I|a2I*b2R| <-------------------------------------------------------+
    +-------+-------+-------+-------+                    
                    | addsub
                    v
    +---------------+---------------+---------------+---------------+
    |a1R*b1R-a1I*b1I|a1R*b1I+a1I*b1R|a2R*b2R-a2I*b2I|a2R*b2I+a2I*b2R|
    +---------------+---------------+---------------+---------------+
*/

// Lane-wise product of two packed complex pairs.
// Layout of a, b and the result: {re0, im0, re1, im1}.
__m256d mult(__m256d a, __m256d b)
{
    // Real parts of a duplicated into both slots of each complex number,
    // then multiplied by b: {aR*bR, aR*bI, ...}.
    const __m256d realProd = _mm256_mul_pd(_mm256_movedup_pd(a), b);
    // Imaginary parts of a duplicated the same way: {aI*bR, aI*bI, ...}.
    const __m256d imagProd = _mm256_mul_pd(_mm256_shuffle_pd(a, a, 15), b);
    // Swap re/im within each complex slot: {aI*bI, aI*bR, ...}.
    const __m256d swapped = _mm256_shuffle_pd(imagProd, imagProd, 5);
    // Even lanes subtract (real part), odd lanes add (imaginary part).
    return _mm256_addsub_pd(realProd, swapped);
}

/*
A plain iterative FFT.
The k-th power of the root of unity is computed incrementally as w * wlen instead of being read from a precomputed table.
This reduces memory accesses and lets the two complex multiplications of each step share one AVX operation.
It also reduces the number of sin/cos calls.
*/
// In-place iterative radix-2 FFT over n complex numbers.
//   n      : transform length; must be a power of two (the bit-reversal
//            step below relies on it).
//   a      : n complex values, each stored as {re, im} in one __m128d.
//   invert : false = forward transform; true = inverse transform, with the
//            result also scaled by 1/n.
// NOTE(review): twiddles are produced by repeated multiplication, so their
// rounding error grows with len; confirm precision for the largest inputs.
void fft(int n, __m128d a[], bool invert)
{
    // Lengths 0 and 1 are their own transform (1/1 scaling is a no-op).
    // Returning early also avoids __builtin_ctz(0) and a shift by 32 below.
    if (n <= 1) return;

    const unsigned un = static_cast<unsigned>(n);

    // Bit-reversal permutation; j tracks reverse(i). i ^ (i + 1) is the
    // contiguous low mask of bits that flip when i is incremented; moving
    // it to the top log2(n) bits yields the bits that flip in reverse(i).
    const unsigned sl = 32 - __builtin_ctz(un);
    for (unsigned i = 0, j = 0; i + 1 < un; ++i)
    {
        if (i < j) std::swap(a[i], a[j]);
        unsigned t = i ^ (i + 1);
        t <<= __builtin_clz(t);
        t >>= sl;
        j ^= t;
    }

    const double pi = std::acos(-1.0);
    for (int len = 2; len <= n; len <<= 1)
    {
        const double ang = 2 * pi / len * (invert ? -1 : 1);
        const int half = len / 2;
        // Lanes 0-1: wlen = e^{i*ang}. Lanes 2-3 are overwritten with the
        // odd-half input before every multiplication.
        __m256d wlen = _mm256_setr_pd(std::cos(ang), std::sin(ang), 0.0, 0.0);
        for (int i = 0; i < n; i += len)
        {
            // Lanes 0-1: current twiddle w, starting at 1 + 0i.
            __m256d w = _mm256_setr_pd(1.0, 0.0, 0.0, 0.0);
            for (int j = 0; j < half; ++j)
            {
                // Copy w into both halves and pair it with {wlen, v}: one
                // mult() gives the next twiddle (low) and w * v (high).
                w = _mm256_permute2f128_pd(w, w, 0);
                wlen = _mm256_insertf128_pd(wlen, a[i + j + half], 1);
                w = mult(w, wlen);
                const __m128d vw = _mm256_extractf128_pd(w, 1);
                const __m128d u = a[i + j];
                a[i + j] = _mm_add_pd(u, vw);
                a[i + j + half] = _mm_sub_pd(u, vw);
            }
        }
    }

    if (invert)
    {
        const __m128d inv = _mm_set1_pd(1.0 / n);
        for (int i = 0; i < n; ++i)
            a[i] = _mm_mul_pd(a[i], inv);
    }
}
// Multiplies two integer polynomials (coefficients listed lowest degree
// first) with a floating-point FFT and rounds the result back to integers.
// Returns n coefficients, where n is the smallest power of two (at least 2)
// that is >= v.size() + w.size() - 1. Entries past the true product length
// are (rounded) zeros. Exactness requires the true coefficients to be
// representable well within double precision.
std::vector<int64_t> multiply(const std::vector<int64_t>& v, const std::vector<int64_t>& w)
{
    // Length of the exact product. Guard the both-empty case, where
    // size() + size() - 1 would wrap around and the doubling loop never ends.
    const size_t total = v.size() + w.size();
    const size_t need = total > 0 ? total - 1 : 0;
    // n stays a power of two (required by fft) and >= 2 (the pointwise loop
    // below consumes two complex values at a time).
    int n = 2;
    while (static_cast<size_t>(n) < need) n <<= 1;

    // One {re, im} slot per coefficient, zero-padded to n. Using vectors
    // means the buffers are released on every exit path.
    std::vector<__m128d> fv(n, _mm_setzero_pd());
    std::vector<__m128d> fw(n, _mm_setzero_pd());
    for (size_t i = 0; i < v.size(); ++i)
        fv[i] = _mm_set_sd(static_cast<double>(v[i]));  // {v[i], 0}
    for (size_t i = 0; i < w.size(); ++i)
        fw[i] = _mm_set_sd(static_cast<double>(w[i]));  // {w[i], 0}

    fft(n, fv.data(), false);
    fft(n, fw.data(), false);

    // Pointwise product, two complex values per AVX register.
    for (int i = 0; i < n; i += 2)
    {
        const __m256d a = _mm256_insertf128_pd(_mm256_castpd128_pd256(fv[i]), fv[i + 1], 1);
        const __m256d b = _mm256_insertf128_pd(_mm256_castpd128_pd256(fw[i]), fw[i + 1], 1);
        const __m256d prod = mult(a, b);
        fv[i] = _mm256_extractf128_pd(prod, 0);
        fv[i + 1] = _mm256_extractf128_pd(prod, 1);
    }

    fft(n, fv.data(), true);

    std::vector<int64_t> ret(n);
    for (int i = 0; i < n; ++i)
        ret[i] = static_cast<int64_t>(std::round(_mm_cvtsd_f64(fv[i])));
    return ret;
}



// Minimal buffered reader for stdin.
static char _buffer[1 << 19]; // Can be changed
static int _currentChar = 0;
static int _charsNumber = 0;

// Returns the next byte of stdin as an unsigned char value (0..255), or -1
// once input is exhausted, so -1 can never be confused with a data byte.
// Refills _buffer with fread whenever it has been fully consumed.
static inline int _read() {
	if (!_charsNumber || _currentChar == _charsNumber) {
		_charsNumber = (int)fread(_buffer, sizeof(_buffer[0]), sizeof(_buffer), stdin);
		_currentChar = 0;
	}
	if (_charsNumber <= 0) return -1;
	return static_cast<unsigned char>(_buffer[_currentChar++]);
}

// Reads the next decimal integer, with an optional leading '-'. Bytes <= 32
// (space and control characters) act as separators. Returns 0 if the input
// ends before any token is found; no overflow or digit validation is done.
static inline int _readInt() {
	int c = _read();
	// Stop at end of input: -1 is also <= 32, so without this check the
	// skip loop would never terminate.
	while (c != -1 && c <= 32) c = _read();
	if (c == -1) return 0;
	int sign = 1;
	if (c == '-') {
		sign = -1;
		c = _read();
	}
	int x = 0;
	while (c > 32) {
		x = x * 10 + (c - '0');
		c = _read();
	}
	return sign < 0 ? -x : x;
}


// Reads two polynomials (degree, then coefficients) and prints the XOR of
// all coefficients of their product.
// Each coefficient is split into PARTS chunks of BITS bits, which keeps the
// per-chunk FFT products small (presumably for double-precision exactness).
// NOTE(review): assumes coefficients are non-negative and < 2^(BITS*PARTS);
// higher bits are dropped by the split below.
int main()
{
    constexpr int BITS = 7;
    constexpr int PARTS = 3;
    constexpr int64_t MASK = (int64_t{1} << BITS) - 1;

    const int N = _readInt() + 1;
    const int M = _readInt() + 1;
    std::vector<int64_t> rawA(N), rawB(M);
    for (int i = 0; i < N; ++i) rawA[i] = _readInt();
    for (int i = 0; i < M; ++i) rawB[i] = _readInt();

    // a[p][j] / b[p][j] hold bits [BITS*p, BITS*(p+1)) of each coefficient.
    std::vector<int64_t> a[PARTS], b[PARTS];
    for (int p = 0; p < PARTS; ++p)
    {
        a[p].resize(N);
        b[p].resize(M);
        for (int j = 0; j < N; ++j)
        {
            a[p][j] = rawA[j] & MASK;
            rawA[j] >>= BITS;
        }
        for (int j = 0; j < M; ++j)
        {
            b[p][j] = rawB[j] & MASK;
            rawB[j] >>= BITS;
        }
    }

    // Every pair (chunk i of the first polynomial, chunk j of the second)
    // contributes to the product, so all PARTS*PARTS products are needed.
    std::vector<int64_t> res[PARTS][PARTS];
    for (int i = 0; i < PARTS; ++i)
        for (int j = 0; j < PARTS; ++j)
            res[i][j] = multiply(a[i], b[j]);

    // Recombine: chunk product (j, k) carries weight 2^(BITS*(j+k)).
    // Multiplication (not <<) keeps this well-defined if a rounded FFT
    // value ever comes out negative.
    int64_t ans = 0;
    const size_t len = res[0][0].size();
    for (size_t i = 0; i < len; ++i)
    {
        int64_t t = 0;
        for (int j = 0; j < PARTS; ++j)
            for (int k = 0; k < PARTS; ++k)
                t += res[j][k][i] * (int64_t{1} << (BITS * (j + k)));
        ans ^= t;
    }
    std::cout << ans << '\n';
    return 0;
}







Compilation message

tt.cpp: In function 'void fft(int, __m128d*, bool)':
tt.cpp:69:25: warning: comparison between signed and unsigned integer expressions [-Wsign-compare]
  for(unsigned i=0,j=0; i<n-1;++i)
                         ^
tt.cpp: In function 'std::vector<long int> multiply(std::vector<long int>&, std::vector<long int>&)':
tt.cpp:107:24: warning: comparison between signed and unsigned integer expressions [-Wsign-compare]
     int n = 2; while(n < v.size()+w.size()-1) n<<=1; 
                        ^
tt.cpp:113:19: warning: comparison between signed and unsigned integer expressions [-Wsign-compare]
     for(int i=0; i<v.size(); ++i)
                   ^
tt.cpp:115:19: warning: comparison between signed and unsigned integer expressions [-Wsign-compare]
     for(int i=0; i<w.size(); ++i)
                   ^
tt.cpp: In function 'int main()':
tt.cpp:205:19: warning: comparison between signed and unsigned integer expressions [-Wsign-compare]
     for(int i=0; i<res[0][0].size(); ++i)
                   ^
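| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| 1 | Incorrect | 0 ms | 2576 KB | Output isn't correct |
| 2 | Halted | 0 ms | 0 KB | - |

| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| 1 | Incorrect | 53 ms | 6756 KB | Output isn't correct |
| 2 | Halted | 0 ms | 0 KB | - |

| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| 1 | Incorrect | 1053 ms | 65648 KB | Output isn't correct |
| 2 | Halted | 0 ms | 0 KB | - |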