Submission #51783

# Submission time Handle Problem Language Result Execution time Memory
51783 2018-06-21T07:27:49 Z ho94949 씽크스몰 (kriii3_TT) PyPy
0 / 30
5678 ms 370364 KB
import math, sys
def fft(a, inv):
    N = len(a)
    j = 1
    for i in range(1, N):
        b = N >> 1
        while j >= b:
            j -= b
            b >>= 1
        j += 1
        if i<j:
            a[i], a[j] = a[j], a[i]
    ang = 2 * math.pi / N * (-1 if inv else 1)
    roots = [complex(math.cos(ang*i), math.sin(ang*i)) for i in xrange(N)]
    
    i = 2
    while i <= N:
        step = N // i
        for j in range(0, N, i):
            for k in range(0, i>>1):
                u, v = a[j+k], a[j+k+i//2]*roots[step*k]
                a[j+k] = u+v
                a[j+k+i//2] = u- v
        i <<= 1
    
    if inv:
        for i in range(N):
            a[i] /= N
    
def mult(v, w):
    N = 2
    while N < len(v) + len(w): N <<= 1
    v1 = [complex(0, 0) for i in range(N)]
    v2 = [complex(0, 0) for i in range(N)]
    r1 = [complex(0, 0) for i in range(N)]
    r2 = [complex(0, 0) for i in range(N)]
    for i in range(len(v)):
        v1[i] = complex(v[i] >> 15, v[i] & 32767)
    for i in range(len(w)):
        v2[i] = complex(w[i] >> 15, w[i] & 32767)
    fft(v1, False)
    fft(v2, False)
    
    for i in range(N):
        j = (N-i) if i!=0 else i
        ans1 = (v1[i]+v1[j].conjugate()) * complex(0.5, 0)
        ans2 = (v1[i]-v1[j].conjugate()) * complex(0, -0.5)
        ans3 = (v2[i]+v2[j].conjugate()) * complex(0.5, 0)
        ans4 = (v2[i]-v2[j].conjugate()) * complex(0, -0.5)
        r1[i] = (ans1*ans3)+(ans1*ans4) * complex(0, 1)
        r2[i] = (ans2*ans3)+(ans2*ans4) * complex(0, 1)
    
    fft(r1, True)
    fft(r2, True)
    
    ret = [0 for i in range(N)]
    for i in range(N):
        av = int(round(r1[i].real))
        bv = int(round(r1[i].imag))+int(round(r2[i].real))
        cv = int(round(r2[i].imag))
        ret[i] = (av<<30)+(bv<<15)+cv
    
    return ret

def main():
    N, M = map(int, sys.stdin.readline().split())
    X = map(int, sys.stdin.readline().split())
    Y = map(int, sys.stdin.readline().split())
    Z = mult(X, Y)
    ans = 0
    for i in Z:
        ans ^= i
    sys.stdout.write(str(ans)+"\n")
    return

if __name__ == '__main__':
    main()
# Verdict Execution time Memory Grader output
1 Correct 32 ms 11748 KB Output is correct
2 Incorrect 32 ms 11748 KB Output isn't correct
3 Halted 0 ms 0 KB -
# Verdict Execution time Memory Grader output
1 Incorrect 343 ms 40648 KB Output isn't correct
2 Halted 0 ms 0 KB -
# Verdict Execution time Memory Grader output
1 Incorrect 5678 ms 370364 KB Output isn't correct
2 Halted 0 ms 0 KB -