Submission #1366989

# | Time | Username | Problem | Language | Result | Execution time | Memory
1366989makonWizdomiot (KAISTRUN26SPRING_I)Pypy 3
100 / 100
2305 ms | 302252 KiB
import sys
# Read lines directly from sys.stdin; note this shadows the builtin input()
# and, unlike it, returned lines keep their trailing newline.
input = sys.stdin.readline

# Modulus under which every sum in this file is reduced.
MOD = 998244353
# Modular inverse of 2 modulo MOD (2 * I2 == MOD + 1).
I2 = 499122177
# Modular inverse of 6 modulo MOD (6 * I6 == MOD + 1).
I6 = 166374059

class ST:
    """Iterative max tree over the 1-based positions ``1..n``.

    Position ``p`` is stored in the leaf ``t[sz + p - 1]``; every internal
    node ``i`` holds the larger of its children ``2*i`` and ``2*i + 1``.
    All cells start at 0.
    """

    __slots__ = ("n", "sz", "t")

    def __init__(self, n):
        """Allocate the tree with a power-of-two leaf count of at least n."""
        leaves = 1
        while leaves < n:
            leaves += leaves
        self.n = n
        self.sz = leaves
        self.t = [0] * (2 * leaves)

    def upd(self, p, x):
        """Store x at position p and refresh the maxima above it."""
        tree = self.t
        node = self.sz + p - 1
        if tree[node] == x:
            # Nothing changes, so no ancestor can change either.
            return
        tree[node] = x
        node >>= 1
        while node:
            left_val = tree[2 * node]
            right_val = tree[2 * node + 1]
            best = right_val if right_val > left_val else left_val
            if tree[node] == best:
                # This ancestor is unaffected, hence so are all higher ones.
                return
            tree[node] = best
            node >>= 1

    def prv(self, r, x):
        """Return the largest position p <= r whose value exceeds x, else 0."""
        if r <= 0:
            return 0
        tree = self.t
        base = self.sz
        node = base + r - 1

        if tree[node] > x:
            return r

        # Climb towards the root; a left sibling covers positions further left.
        while node > 1:
            if node & 1 and tree[node ^ 1] > x:
                cand = node ^ 1
                # Descend, preferring the right child to stay as far right
                # as possible.
                while cand < base:
                    cand = 2 * cand + 1
                    if tree[cand] <= x:
                        cand -= 1
                pos = cand - base + 1
                return pos if pos <= self.n else 0
            node >>= 1

        return 0

    def nxt(self, l, x):
        """Return the smallest position p >= l whose value exceeds x, else 0."""
        if l > self.n:
            return 0
        tree = self.t
        base = self.sz
        node = base + l - 1

        if tree[node] > x:
            return l

        # Climb towards the root; a right sibling covers positions further right.
        while node > 1:
            if not node & 1 and tree[node ^ 1] > x:
                cand = node ^ 1
                # Descend, preferring the left child to stay as far left
                # as possible.
                while cand < base:
                    cand *= 2
                    if tree[cand] <= x:
                        cand += 1
                pos = cand - base + 1
                return pos if pos <= self.n else 0
            node >>= 1

        return 0

def _sum_linear(f, s1, s0):
    """Sum of ``f[0]*t + f[1]`` over a range with power sums s1 (t) and s0 (1)."""
    slope, shift = f
    return (slope * s1 + shift * s0) % MOD


def _sum_product(f, g, s2, s1, s0):
    """Sum of ``(f[0]*t + f[1]) * (g[0]*t + g[1])`` given power sums s2, s1, s0."""
    fa, fb = f
    ga, gb = g
    return (fa * ga * s2 + (fa * gb + fb * ga) * s1 + fb * gb * s0) % MOD


def sm(l, r, u, v, e, a):
    """Sum a quadratic expression of four linear functions over t = l..r, mod MOD.

    ``u``, ``v``, ``e`` and ``a`` are ``(slope, intercept)`` pairs describing
    ``slope * t + intercept``.  The value returned is the sum over every
    integer t in [l, r] of

        (v*v - u*u + u + v) / 2 - v*v - v + a*v + e*u - a*e

    taken modulo MOD (division by 2 is done with the modular inverse I2).
    The result lies in ``[0, MOD)``.
    """
    # Power sums of t over [l, r]: count, sum of t, sum of t*t.
    s0 = (r - l + 1) % MOD
    s1 = (l + r) % MOD * s0 % MOD * I2 % MOD

    # 6 * (1^2 + ... + m^2) == m * (m + 1) * (2m + 1); take it at r and l - 1.
    top = r % MOD
    below = (l - 1) % MOD
    six_top = top * ((top + 1) % MOD) % MOD * ((2 * top + 1) % MOD)
    six_below = below * ((below + 1) % MOD) % MOD * ((2 * below + 1) % MOD)
    s2 = (six_top - six_below) % MOD * I6 % MOD

    sum_u = _sum_linear(u, s1, s0)
    sum_v = _sum_linear(v, s1, s0)
    sum_uu = _sum_product(u, u, s2, s1, s0)
    sum_vv = _sum_product(v, v, s2, s1, s0)
    sum_av = _sum_product(a, v, s2, s1, s0)
    sum_eu = _sum_product(e, u, s2, s1, s0)
    sum_ae = _sum_product(a, e, s2, s1, s0)

    halved = (sum_vv - sum_uu + sum_u + sum_v) * I2
    return (halved - sum_vv - sum_v + sum_av + sum_eu - sum_ae) % MOD

class DS:
    """Maintains maximal runs of switched-on positions in 1..n plus a running sum.

    Positions are switched on one at a time with ``add``.  Consecutive on
    positions form a run; runs are merged with a union-find, and each run's
    length is mirrored into two ``ST`` trees (keyed by the run's left end and
    by its right end) so the nearest sufficiently long run on either side can
    be located.  Every ``add`` computes a contribution ``d`` via ``sm`` and
    accumulates it into ``cur`` modulo MOD.

    NOTE(review): what ``cur`` counts in terms of the original problem is not
    visible in this file; the comments below describe the mechanics only.
    """
    __slots__ = ("n", "on", "p", "l", "r", "w", "sr", "er", "sl", "sg", "cur")

    def __init__(self, n):
        """Create the structure for positions 1..n with every position off."""
        self.n = n
        # on[i] becomes 1 once position i has been added.
        self.on = bytearray(n + 2)
        # Union-find parents; every position starts as its own root.
        self.p = list(range(n + 2))
        # Run data, valid at a run's root: left end, right end, length.
        self.l = [0] * (n + 2)
        self.r = [0] * (n + 2)
        self.w = [0] * (n + 2)

        # sr[i] / er[i]: root of the run starting / ending at position i, else 0.
        self.sr = [0] * (n + 2)
        self.er = [0] * (n + 2)

        # Max trees holding each run's length at its left end (sl) and at its
        # right end (sg); all other positions hold 0.
        self.sl = ST(n)
        self.sg = ST(n)

        # Sum of all contributions so far, kept in [0, MOD).
        self.cur = 0

    def fd(self, x):
        """Return the union-find root of x, shortening the path on the way."""
        p = self.p
        while p[x] != x:
            # Re-point x at its grandparent, then step there.
            p[x] = p[p[x]]
            x = p[x]
        return x

    def gl(self, lim, b, c):
        """Describe, for t in b+1..c, a piecewise-linear function built from runs left of lim.

        Returns a list of ``(lo, hi, (slope, intercept))`` tuples whose
        ``lo..hi`` ranges are consecutive and together cover ``b + 1 .. c``.
        Each step looks up the largest position ``q <= lim`` at which ``sg``
        holds a value greater than the last covered t, i.e. the right end of
        the nearest run longer than that.  That run supplies the coefficients
        ``(-1, q + 1)`` for t up to its length (capped at c), and the search
        continues strictly left of that run.  When no such run exists the
        remainder of the range gets ``(0, 0)``.

        NOTE(review): ``q + 1 - t`` looks like the start of a length-t window
        ending at q — confirm against the problem statement.
        """
        ret = []
        # pm is the largest t already covered by a piece.
        pm = b

        sg = self.sg
        er = self.er
        w = self.w
        l = self.l

        while pm < c:
            q = sg.prv(lim, pm)
            if not q:
                # No run longer than pm remains on the left.
                ret.append((pm + 1, c, (0, 0)))
                break

            rt = er[q]
            ln = w[rt]
            hi = ln if ln < c else c

            ret.append((pm + 1, hi, (-1, q + 1)))

            # Larger t needs a longer run, which must lie further left.
            pm = ln
            lim = l[rt] - 1

        return ret

    def gr(self, lim, b, c):
        """Describe, for t in b+1..c, a piecewise-linear function built from runs right of lim.

        Mirror image of ``gl``: each step looks up the smallest position
        ``s >= lim`` at which ``sl`` holds a value greater than the last
        covered t, i.e. the left end of the nearest run longer than that.
        That run supplies the constant coefficients ``(0, s)`` for t up to its
        length (capped at c), and the search continues strictly right of it.
        When no such run exists the remainder gets ``(-1, n + 2)``.

        Returns a list of ``(lo, hi, (slope, intercept))`` tuples covering
        ``b + 1 .. c`` with consecutive ranges.
        """
        ret = []
        # pm is the largest t already covered by a piece.
        pm = b
        n = self.n

        sl = self.sl
        sr = self.sr
        w = self.w
        r = self.r

        while pm < c:
            s = sl.nxt(lim, pm)
            if not s:
                # No run longer than pm remains on the right.
                ret.append((pm + 1, c, (-1, n + 2)))
                break

            rt = sr[s]
            ln = w[rt]
            hi = ln if ln < c else c

            ret.append((pm + 1, hi, (0, s)))

            # Larger t needs a longer run, which must lie further right.
            pm = ln
            lim = r[rt] + 1

        return ret

    def add(self, x):
        """Switch on position x, merge it with adjacent runs and update ``cur``.

        Steps, in order:
        1. Find the runs directly left and right of x (lengths a and b); the
           merged run is [L, R] with length c = a + 1 + b.
        2. Build four piecewise-linear functions of t over 1..c (``el``,
           ``al``, ``ul``, ``vl``) and sum ``sm`` over their common pieces to
           get the contribution d.  This must happen before step 3 because
           ``gl``/``gr`` query the trees in their pre-merge state.
        3. Remove the old runs from the trees and endpoint tables, union the
           sets, record the merged run, and add d to ``cur``.
        """
        n = self.n
        on = self.on
        p = self.p
        w = self.w
        l = self.l
        r = self.r

        # Whether the left / right neighbour exists and is already on.
        hl = x > 1 and on[x - 1]
        hr = x < n and on[x + 1]

        # Roots of the neighbouring runs, 0 when there is none.
        lr = self.fd(x - 1) if hl else 0
        rr = self.fd(x + 1) if hr else 0

        # Lengths of the neighbouring runs.
        a = w[lr] if lr else 0
        b = w[rr] if rr else 0

        # Bounds and length of the merged run.
        L = x - a
        R = x + b
        c = a + 1 + b

        # el: for t <= a the value is x - t (taken from the left neighbour);
        # beyond that, pieces come from runs further left via gl.
        el = []
        if a:
            el.append((1, a, (-1, x)))
        # Always true since c = a + 1 + b > a.
        if a < c:
            el.extend(self.gl(L - 1, a, c))

        # al: for t <= b the value is x + 1 (taken from the right neighbour);
        # beyond that, pieces come from runs further right via gr.
        al = []
        if b:
            al.append((1, b, (0, x + 1)))
        # Always true since c = a + 1 + b > b.
        if b < c:
            al.extend(self.gr(R + 1, b, c))

        # ul: x + 1 - t while t <= a + 1, then the constant L.
        h = a + 1
        if h >= c:
            ul = [(1, c, (-1, x + 1))]
        else:
            ul = [(1, h, (-1, x + 1)), (h + 1, c, (0, L))]

        # vl: the constant x while t <= b + 1, then R + 1 - t.
        h = b + 1
        if h >= c:
            vl = [(1, c, (0, x))]
        else:
            vl = [(1, h, (0, x)), (h + 1, c, (-1, R + 1))]

        # Four-way merge: walk t from 1 to c, cutting at every piece boundary
        # so that all four functions are linear on [t, hi].  Each list's last
        # piece ends at c, so the indices stay valid while t <= c.
        i = j = k = m = 0
        t = 1
        d = 0

        while t <= c:
            e = el[i]
            aa = al[j]
            u = ul[k]
            v = vl[m]

            # hi = nearest upper bound among the four current pieces.
            hi = e[1]
            if aa[1] < hi:
                hi = aa[1]
            if u[1] < hi:
                hi = u[1]
            if v[1] < hi:
                hi = v[1]

            d += sm(t, hi, u[2], v[2], e[2], aa[2])
            # Reduce only when the accumulator gets very large.
            if d >= (1 << 63):
                d %= MOD

            t = hi + 1

            # Advance every list whose current piece has been used up.
            if t > e[1]:
                i += 1
            if t > aa[1]:
                j += 1
            if t > u[1]:
                k += 1
            if t > v[1]:
                m += 1

        d %= MOD

        sl = self.sl
        sg = self.sg
        sr = self.sr
        er = self.er

        # Drop the left neighbour's run from the trees and endpoint tables.
        if lr:
            ol = l[lr]
            orr = r[lr]

            sl.upd(ol, 0)
            sg.upd(orr, 0)

            sr[ol] = 0
            er[orr] = 0

        # Drop the right neighbour's run likewise.
        if rr:
            ol = l[rr]
            orr = r[rr]

            sl.upd(ol, 0)
            sg.upd(orr, 0)

            sr[ol] = 0
            er[orr] = 0

        # Root of the merged set: the longer neighbour's root (left wins
        # ties), or x itself when it has no neighbours.
        if a >= b and a:
            rt = lr
        elif b:
            rt = rr
        else:
            rt = x

        p[x] = rt

        if lr and lr != rt:
            p[lr] = rt
        if rr and rr != rt:
            p[rr] = rt

        # Record the merged run at its root.
        l[rt] = L
        r[rt] = R
        w[rt] = c

        on[x] = 1

        sr[L] = rt
        er[R] = rt

        sl.upd(L, c)
        sg.upd(R, c)

        # Both operands are below MOD, so one conditional subtraction suffices.
        self.cur += d
        if self.cur >= MOD:
            self.cur -= MOD

def main():
    """Read n and n heights from stdin, process them tallest-first, print the total.

    Positions are fed to a ``DS`` in order of decreasing height (ties broken
    by larger position first).  After each group of equal heights has been
    added, ``ds.cur`` is weighted by the gap down to the next distinct height
    (or to 0 after the last group) and added to the answer modulo MOD.
    """
    n = int(input())
    heights = list(map(int, input().split()))

    # (height, 1-based position) pairs in descending order.
    order = sorted(((heights[k], k + 1) for k in range(n)), reverse=True)

    ds = DS(n)
    total = 0
    idx = 0

    while idx < n:
        level = order[idx][0]

        # Add every position sharing the current height.
        stop = idx
        while stop < n and order[stop][0] == level:
            ds.add(order[stop][1])
            stop += 1
        idx = stop

        if idx < n:
            below = order[idx][0]
        else:
            below = 0
        total = (total + (level - below) * ds.cur) % MOD

    print(total)


if __name__ == "__main__":
    main()

Compilation message (stdout)

Compiling 'Main.py'...

=======
  adding: __main__.pyc (deflated 52%)

=======
#Result Execution timeMemoryGrader output
Fetching results...
#Result Execution timeMemoryGrader output
Fetching results...
#Result Execution timeMemoryGrader output
Fetching results...
#Result Execution timeMemoryGrader output
Fetching results...
#Result Execution timeMemoryGrader output
Fetching results...
#Result Execution timeMemoryGrader output
Fetching results...
#Result Execution timeMemoryGrader output
Fetching results...
#Result Execution timeMemoryGrader output
Fetching results...
#Result Execution timeMemoryGrader output
Fetching results...
#Result Execution timeMemoryGrader output
Fetching results...
#Result Execution timeMemoryGrader output
Fetching results...
#Result Execution timeMemoryGrader output
Fetching results...