Submission #1253450

#TimeUsernameProblemLanguageResultExecution timeMemory
1253450ankiteTriple Jump (JOI19_jumps)Pypy 3
27 / 100
555 ms255380 KiB
import math
import sys

class SegmentTree:
    def __init__(self, n):
        self.n = n
        self.size = 1
        while self.size < n:
            self.size *= 2
        self.data = [-10**18] * (2 * self.size)
    
    def update(self, i, val):
        i += self.size
        if val > self.data[i]:
            self.data[i] = val
        i //= 2
        while i:
            self.data[i] = max(self.data[2*i], self.data[2*i+1])
            i //= 2
            
    def query(self, l, r):
        l += self.size
        r += self.size
        res = -10**18
        while l <= r:
            if l % 2 == 1:
                res = max(res, self.data[l])
                l += 1
            if r % 2 == 0:
                res = max(res, self.data[r])
                r -= 1
            l //= 2
            r //= 2
        return res

def main():
    data = sys.stdin.read().split()
    if not data:
        return
    
    n = int(data[0])
    A = list(map(int, data[1:1+n]))
    q = int(data[1+n])
    queries = []
    index = 1+n+1
    for i in range(q):
        L = int(data[index]); R = int(data[index+1]); index += 2
        queries.append((L-1, R-1, i))
    
    LOG = (n).bit_length()
    st_table = [[(0,0)] * n for _ in range(LOG)]
    
    for i in range(n):
        st_table[0][i] = (A[i], i)
        
    for j in range(1, LOG):
        step = 1 << (j-1)
        for i in range(n - (1<<j) + 1):
            left_val, left_idx = st_table[j-1][i]
            right_val, right_idx = st_table[j-1][i+step]
            if left_val > right_val:
                st_table[j][i] = (left_val, left_idx)
            elif left_val < right_val:
                st_table[j][i] = (right_val, right_idx)
            else:
                st_table[j][i] = (left_val, min(left_idx, right_idx))
                
    def query_range_max(l, r):
        if l > r:
            return (-10**18, -1)
        length = r - l + 1
        j = length.bit_length() - 1
        if j >= LOG:
            j = LOG-1
        seg1 = st_table[j][l]
        seg2 = st_table[j][r - (1<<j) + 1]
        if seg1[0] > seg2[0]:
            return seg1
        elif seg1[0] < seg2[0]:
            return seg2
        else:
            return (seg1[0], min(seg1[1], seg2[1]))
    
    nxt = [[] for _ in range(n)]
    stack = []
    for i in range(n):
        while stack and A[stack[-1]] < A[i]:
            top = stack.pop()
            nxt[top].append(i)
        if stack:
            nxt[stack[-1]].append(i)
        stack.append(i)
    
    events = []
    for i in range(n):
        for j in nxt[i]:
            k_low = 2*j - i
            if k_low < n:
                if k_low <= n-1:
                    val, idx = query_range_max(k_low, n-1)
                    total_val = A[i] + A[j] + val
                    events.append((idx, i, total_val))
    
    events.sort(key=lambda x: x[0])
    queries_sorted = sorted(queries, key=lambda x: x[1])
    
    seg_tree = SegmentTree(n)
    ans = [-10**18] * q
    event_ptr = 0
    query_ptr = 0
    
    for R in range(n):
        while event_ptr < len(events) and events[event_ptr][0] <= R:
            k, i, val = events[event_ptr]
            seg_tree.update(i, val)
            event_ptr += 1
            
        while query_ptr < len(queries_sorted) and queries_sorted[query_ptr][1] == R:
            l0, r0, idx = queries_sorted[query_ptr]
            res = seg_tree.query(l0, n-1)
            ans[idx] = res
            query_ptr += 1
            
    for i in range(q):
        print(ans[i])

if __name__ == '__main__':
    main()

Compilation message (stdout)

Compiling 'jumps.py'...

=======
  adding: __main__.pyc (deflated 47%)

=======
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...