답안 #973844

# 제출 시각 아이디 문제 언어 결과 실행 시간 메모리
973844 2024-05-02T11:45:57 Z canadavid1 Spiral (BOI16_spiral) PyPy 3
0 / 100
64 ms 21512 KB
# Spiral walked so far: position[k] is (cell, heading) for the cell numbered k + 1,
# where cells are complex numbers x + y*i and heading is a unit step (1, 1j, -1, -1j).
position = [(0,1)]
# Reverse index: complex cell -> number written in it.
position_map = {0:1}
def grid(x,y):
    """Return the number written at (x, y) in the counter-clockwise spiral.

    Cell (0, 0) holds 1 and the walk starts heading east.  The spiral is
    extended lazily, one cell at a time, until (x, y) has been visited; the
    visited prefix is cached in the module-level ``position``/``position_map``.
    """
    target = x + 1j * y
    while target not in position_map:
        here, heading = position[-1]
        step = here + heading
        position.append((step, heading))
        position_map[step] = len(position)
        # Turn left as soon as the cell on the left-hand side is still unvisited.
        left_neighbour = step + 1j * heading
        if left_neighbour not in position_map:
            position[-1] = (step, heading * 1j)
    return position_map[target]


# for i in range(5,-6,-1):
#     for j in range(-5,6):
#         print(grid(j,i),end="\t")
#     print()


def sum_square_tr(n):
    """Return the sum of grid(x, y) over the n x n square with -n <= x, y < 0.

    Cells are visited column by column (x outer, y inner), exactly as a
    nested loop would.
    """
    span = range(-n, 0)
    return sum(grid(cx, cy) for cx in span for cy in span)


def eval_poly(poly,num,mod=None):
    """Evaluate a coefficient list (lowest degree first) at ``num``.

    Each power of ``num`` is computed with ``pow(num, power, mod)``; the summed
    value is rounded to an int and, when ``mod`` is given, reduced modulo it.
    """
    terms = (coeff * pow(num, power, mod) for power, coeff in enumerate(poly))
    value = round(sum(terms))
    if mod is None:
        return value
    return value % mod

def mul_poly(p,q):
    """Return the coefficient list of the product of polynomials p and q.

    Both inputs and the result list coefficients lowest degree first.
    """
    product = [0] * (len(p) + len(q) - 1)
    for shift, coeff in enumerate(p):
        # Add coeff * q, shifted up by `shift` degrees, into the product.
        for k, term in enumerate(coeff * w for w in q):
            product[shift + k] += term
    return product

def add_poly(p,q):
    """Return the coefficient-wise sum of polynomials p and q (lowest degree first).

    The shorter list is padded with zeros so no high-degree terms are lost;
    for equal-length inputs this is a plain element-wise sum.
    """
    from itertools import zip_longest
    return [a + b for a, b in zip_longest(p, q, fillvalue=0)]

def lagrange_poly(ord,at):
    """Return the Lagrange basis polynomial for node ``at`` over nodes 0..ord-1.

    The coefficient list (lowest degree first) is the product of (X - r) for
    every node r != at, divided by that product's value at ``at``.  The result
    is therefore 1 at ``at`` and 0 at the other nodes.  True division is used,
    so the coefficients are floats.
    """
    other_nodes = [r for r in range(ord) if r != at]
    basis = [1]
    for r in other_nodes:
        basis = mul_poly(basis, [-r, 1])
    denominator = eval_poly(basis, at)
    return [c / denominator for c in basis]

def fit_poly(seq):
    """Interpolate the polynomial through the points (i, seq[i]), i = 0..len(seq)-1.

    Uses Lagrange interpolation and returns the coefficient list (lowest degree
    first) of the polynomial of degree < len(seq).
    """
    size = len(seq)
    coeffs = [0] * size
    for node, value in enumerate(seq):
        scaled_basis = mul_poly(lagrange_poly(size, node), [value])
        coeffs = add_poly(coeffs, scaled_basis)
    return list(coeffs)

    
# Sums of the n x n squares just below-left of the origin for n = 0..5.
# Only consumed by the commented-out polynomial-fitting debug prints below.
sqtr = list(map(sum_square_tr, range(6)))
#print(sqtr)
#print(fit_poly(sqtr))

def sgn(x):
    """Return the sign of x: 1 if positive, 0 if zero, otherwise -1."""
    if x > 0:
        return 1
    if x == 0:
        return 0
    return -1

from fractions import Fraction

# Coefficient lists (lowest degree first) consumed by rect_center via eval_poly.
#
# diag[x >= 0][y >= 0](n): sum of the n x n square touching the origin in the
# quadrant selected by the signs of x and y.  Integer coefficients, so these are
# safe to evaluate modulo MOD.
xpyp = [0,1,2,-4,2]
xnyp = [0,0,4,0,2]
xpyn = [0,-4,10,0,2]
xnyn = [0,-1,2,4,2]
diag = [[xnyn,xnyp],[xpyn,xpyp]]

# col[0][x >= 0](k) / col[1][y >= 0](k): one-cell-wide strip sums along the x /
# y axis.  The coefficients are rational.  They are kept as exact Fractions
# because rect_center evaluates them WITHOUT a modulus at |x|, |y| up to ~1e9.
# There the cubic term (~1e27) far exceeds the 53-bit float mantissa, and float
# coefficients would round to the wrong integer.
xp = [0, Fraction(19, 6), Fraction(-7, 2), Fraction(4, 3)]
xn = [0, Fraction(13, 6), Fraction(5, 2), Fraction(4, 3)]
yp = [0, Fraction(13, 6), Fraction(-5, 2), Fraction(4, 3)]
yn = [0, Fraction(19, 6), Fraction(7, 2), Fraction(4, 3)]
col = [[xn,xp],[yn,yp]]

MOD = 10**9+7

def helper(x,y):
    """Debug helper: return (f, N - n), or () when either part is not positive.

    n and N are the smaller and larger of |x| and |y|.  f is n minus 1 when the
    shorter-axis coordinate is non-negative, and n otherwise.  These are the
    same quantities rect_center uses in its strip correction.  This helper is
    only referenced from the commented-out debug harness.
    """
    n = min(abs(x), abs(y))
    N = max(abs(x), abs(y))
    # Shorter-axis coordinate: x when |x| < |y|, otherwise y (ties pick y).
    short = x if abs(x) < abs(y) else y
    f = n - (short >= 0)
    if N - n > 0 and f > 0:
        return f, N - n
    return ()


def rect_center(x: int,y: int) -> int: # rectangle from the lower left corner of (0,0) to lower left of (x,y)
    """Closed-form sum of the spiral over the rectangle between the origin and (x, y), mod MOD.

    Intended to equal rect_center_naive(x, y) % MOD, i.e. the sum of grid(cx, cy)
    for cx in [min(x, 0), max(x, 0)) and cy in [min(y, 0), max(y, 0)).

    The rectangle is split into two parts:
    - the n x n square touching the origin, where n = min(|x|, |y|);
    - the strip that extends the square to length N = max(|x|, |y|) along the
      longer axis.
    Each part is evaluated from the precomputed polynomials in ``diag`` / ``col``.

    NOTE(review): the submission is graded "Output isn't correct" even on the
    first subtask, so some quadrant/orientation cases of this closed form are
    presumably still wrong.  Verify against rect_center_naive before relying on it.
    NOTE(review): ``eval_poly(rp, ...)`` below runs without a modulus; if the
    ``col`` polynomials hold float coefficients, large N loses precision.
    """
    # True when the vertical extent is strictly longer (ties count as horizontal).
    ygtx = abs(x) < abs(y)
    # Whether the longer-axis coordinate is >= 0.
    sm = [x,y][ygtx]>=0
    # Whether the shorter-axis coordinate is >= 0.
    sN = [y,x][ygtx]>=0
    # -1 when an odd number of {ygtx, x >= 0, y >= 0} hold, else +1.
    smi = -1 if (ygtx != (x >= 0)) != (y >= 0) else 1
    n = min(abs(x),abs(y))
    N = max(abs(x),abs(y))
    # Square part: quadrant-specific polynomial in n, evaluated mod MOD.
    sqp = diag[x >= 0][y >= 0]
    square = eval_poly(sqp,n,MOD)
    # Strip part: one-wide strip polynomial along the longer axis, taken from n to N
    # and scaled by the strip width n.
    rp = col[ygtx][sm]
    sc: int = (eval_poly(rp,N)-eval_poly(rp,n))
    sc *= n
    # Triangular correction, presumably for values changing across the strip's
    # width (cf. the "extra since it is decreasing" note at the end of the file).
    f = n-sN
    sc += ((N-n)*f*(f+1))//2 * smi
    # Ad-hoc correction for the x > 0, y < 0, |y| < x region; its derivation is
    # not shown here -- TODO confirm.
    if x > 0 and y < 0 and -y < x: sc -= 8*y
    return (square + sc)%MOD

def rect_center_naive(x,y):
    """Brute-force reference for rect_center (without the modulus).

    Sums grid(cx, cy) over cx in [min(x, 0), max(x, 0)) and
    cy in [min(y, 0), max(y, 0)).
    """
    xs = range(min(x, 0), max(x, 0))
    ys = range(min(y, 0), max(y, 0))
    return sum(grid(cx, cy) for cx in xs for cy in ys)

def rect(xA,yA,xB,yB):
    """Return the spiral sum over the inclusive rectangle spanned by (xA, yA) and (xB, yB), mod MOD.

    The rectangle is turned into half-open ranges [x0, x1) x [y0, y1) and then
    combined by inclusion-exclusion over rect_center, which sums the half-open
    rectangle between the origin and one corner.

    Along each axis there are two cases:
    - both endpoints lie strictly on the same side of 0: the corner nearer 0 is
      subtracted;
    - otherwise (the range straddles or touches 0): both corner sums are added.
    """
    xs = sorted([xA, xB])
    xs[1] += 1
    ys = sorted([yA, yB])
    ys[1] += 1

    def weight(bounds, k):
        # -1 for the endpoint nearer 0 when both endpoints share a strict sign.
        same_side = bounds[0] * bounds[1] > 0
        return -1 if same_side and abs(bounds[k]) < abs(bounds[1 - k]) else 1

    total = 0
    for a in (0, 1):
        for b in (0, 1):
            total += rect_center(xs[a], ys[b]) * weight(xs, a) * weight(ys, b)
    # Each rect_center term is already reduced, but the signed combination can
    # be negative or exceed MOD.  Reduce again so the answer lies in [0, MOD).
    return total % MOD

def rect_naive(xA,yA,xB,yB):
    """Brute-force reference for rect (without the modulus).

    Sums grid(cx, cy) over every cell of the inclusive rectangle spanned by
    (xA, yA) and (xB, yB).
    """
    lo_x, hi_x = sorted([xA, xB])
    lo_y, hi_y = sorted([yA, yB])
    return sum(
        grid(cx, cy)
        for cx in range(lo_x, hi_x + 1)
        for cy in range(lo_y, hi_y + 1)
    )
    

#print("here")
#print(rect_center(3,1))
# xA = 2
# yA = -3
# ry = range(10,-10,-1)
# rx = range(-10,10)
# for y in ry:
#     for x in rx:
#         if abs(x)==abs(y) or x == 0 or y == 0:
#             print(("*" if rect_naive(xA,yA,x,y)==rect(xA,yA,x,y) else "#")*6,end="\t")
#         else:
#             print((rect_naive(xA,yA,x,y)-rect(xA,yA,x,y)),end="\t")
#     print()
# exit(0)
# print("-"*80)
# for y in ry:
#     for x in rx:
#         if abs(x)==abs(y) or x == 0 or y == 0:
#             print(end="******  ")
#         else:
#             print(helper(x,y),end="\t")
#     print()

if __name__ == "__main__":
    # First line: an unused value followed by the query count q.
    # Each of the next q lines: xA yA xB yB; print the rectangle's sum mod MOD.
    _,q = map(int,input().split())
    for i in range(q):
        print(rect(*map(int,input().split())))

"""
(1,1) -> 1
(1,2) -> 2
(1,n) -> n
(2,1) -> 3
(2,2) -> 6
(2,n) -> 3n
(3,n) -> 6n
(a,b) -> b*(a*(a+1)/2)

"""

""" 
diag[x >= 0][y >= 0] returns the sum of cells from (0,0) to (pm n,pm n)
col[0][x >= 0] returns the sum of cells from (0,0) to (x,1)
col[1][y >= 0] returns the sum of cells from (0,0) to (1,y)

to calculate in y > x > 0: diag[1][1] (x)
c = col[1][1] (y) - col[1][1] (x) # column from the square to the top
c * x                 # the rectangle at the top
- x * (x-1)/2 * (y-x) # extra since it is decreasing

y > -x > 0:
diag[0][1] (-x)
c = col[1][1] (y) - col[1][1] (-x) # column off to the right from 
c * -x
+ x * (x-1)/2 * (y-x)

\ +--+    /
 \|  |   /
  *--+  /
  |\ | /
  | \|/
  *--.-----
"""
# 결과 실행 시간 메모리 Grader output
1 Incorrect 60 ms 21308 KB Output isn't correct
2 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Incorrect 64 ms 21512 KB Output isn't correct
2 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Incorrect 60 ms 21308 KB Output isn't correct
2 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Incorrect 63 ms 21288 KB Output isn't correct
2 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Incorrect 60 ms 21308 KB Output isn't correct
2 Halted 0 ms 0 KB -