답안 #649166

# 제출 시각 아이디 문제 언어 결과 실행 시간 메모리
649166 2022-10-09T13:05:01 Z boris_mihov Weirdtree (RMI21_weirdtree) C++17
21 / 100
2000 ms 51828 KB
#include "weirdtree.h"
#include <algorithm>
#include <iostream>
#include <numeric>
#include <cassert>
#include <vector>

using llong = long long;             // 64-bit integer used for sums and products
constexpr int MAXN = 300000 + 10;    // capacity of a[]; the segment tree array is 4 * MAXN
constexpr int INF  = 2000000000;     // same value the literal 2e9 converts to

int n, q;        // set by initialise() from its N and Q arguments
int a[MAXN];     // values copied from the caller's array, indexed from 1
struct Number
{
    llong value, count;
    Number()
    {
        value = -1;
        count = 0;
    }

    Number(llong _value, llong _count)
    {
        value = _value;
        count = _count;
    }

    inline friend bool operator < (Number a, Number b)
    {
        return a.value < b.value;
    }

    inline friend bool operator == (Number a, Number b)
    {
        return a.value == b.value;
    }

    inline friend bool operator != (Number a, Number b)
    {
        return !(a == b);
    }

    inline friend bool operator > (Number a, Number b)
    {
        return a.value > b.value;
    }

    inline friend void operator += (Number &a, Number b)
    {
        if (a.value == b.value) a.count += b.count;
        if (a.value < b.value) a = b;
    }
};

struct Node
{
    Number max, max2;  
    llong sum;
    Node()
    {
        max = max2 = {-1, 0};
        sum = 0;
    }

} tree[4*MAXN];

// Merges the summaries of two segments into the summary of their union.
Node combine(Node left, Node right)
{
    Node merged;
    merged.sum = left.sum + right.sum;

    // Overall maximum: the larger of the two maxima, counts added on a tie.
    merged.max = left.max;
    merged.max += right.max;

    // Each side contributes its best value that differs from the overall
    // maximum: its own max if that lost, otherwise its max2.
    merged.max2 = (left.max == merged.max) ? left.max2 : left.max;
    merged.max2 += (right.max == merged.max) ? right.max2 : right.max;

    return merged;
}

void build(int l, int r, int node)
{
    if (l == r)
    {
        tree[node].max2 = {-1, 0};
        tree[node].max = {a[l], 1};
        tree[node].sum = a[l];
        return;
    }

    int mid = (l + r) / 2;
    build(l, mid, 2*node);
    build(mid + 1, r, 2*node + 1);
    tree[node] = combine(tree[2*node], tree[2*node + 1]);
}

// Returns the combined summary of the positions of [l, r] that fall inside
// [queryL, queryR]. `node` is the tree index of the segment [l, r].
Node query(int l, int r, int node, int queryL, int queryR)
{
    const bool covered = (queryL <= l) && (r <= queryR);
    if (covered)
    {
        return tree[node];
    }

    const int mid = (l + r) / 2;
    Node acc;   // starts as the empty node and absorbs each overlapping child

    if (mid >= queryL)
    {
        acc = combine(acc, query(l, mid, 2 * node, queryL, queryR));
    }

    if (queryR > mid)
    {
        acc = combine(acc, query(mid + 1, r, 2 * node + 1, queryL, queryR));
    }

    return acc;
}

void updatePos(int l, int r, int node, int queryPos, int queryVal)
{
    if (l == r)
    {
        tree[node].max2 = {-1, 0};
        tree[node].max = {queryVal, 1};
        tree[node].sum = queryVal;
        return;
    }

    int mid = (l + r) / 2;
    if (queryPos <= mid) updatePos(l, mid, 2*node, queryPos, queryVal);
    else updatePos(mid + 1, r, 2*node + 1, queryPos, queryVal);
    tree[node] = combine(tree[2*node], tree[2*node + 1]);
}

// Search state shared with the callers:
//   toRemove    - how many matching positions still have to be counted;
//                 findFirstPos() decrements it while skipping matches.
//   searchedMax - the value being matched (also read by updateMAXminus()).
int toRemove;
int searchedMax;

// Scans [queryL, queryR] from left to right and returns the position of the
// toRemove-th leaf whose value equals searchedMax, or -1 if the subtree of
// `node` (covering [l, r]) does not contain it. Every match that is skipped
// on the way is subtracted from toRemove, so the count carries over to the
// next subtree the caller descends into.
int findFirstPos(int l, int r, int node, int queryL, int queryR)
{
    if (l == r) 
    {
        assert(queryL <= l && r <= queryR);
        if (searchedMax != tree[node].max.value) return -1;
        if (toRemove != 1)
        {
            // A match, but not the one we are counting towards yet.
            toRemove--;
            return -1;
        }

        return l;
    }

    // Logical AND (the original used bitwise '&' on the two comparisons).
    if (queryL <= l && r <= queryR)
    {
        if (searchedMax != tree[node].max.value || tree[node].max.count < toRemove)
        {
            // The target is not in this segment; account for the matches it holds.
            if (searchedMax == tree[node].max.value) toRemove -= tree[node].max.count;
            return -1;
        }

        // The target is inside this segment: pick the child that contains it.
        int mid = (l + r) / 2;
        if (tree[node].max.value == tree[2*node].max.value && toRemove <= tree[2*node].max.count) 
        {
            return findFirstPos(l, mid, 2*node, queryL, queryR);
        }

        if (searchedMax == tree[2*node].max.value) toRemove -= tree[2*node].max.count;
        return findFirstPos(mid + 1, r, 2*node + 1, queryL, queryR);
    }

    // Partial overlap: try the left part first, then fall through to the right.
    int mid = (l + r) / 2;
    if (queryL <= mid)
    {
        int res = findFirstPos(l, mid, 2*node, queryL, queryR);
        if (res != -1) return res;
    }

    return findFirstPos(mid + 1, r, 2*node + 1, queryL, queryR);
}

void updateMAXminus(int l, int r, int node, int queryL, int queryR, int value)
{
    if (tree[node].max.value < searchedMax) return;
    if (queryL <= l && r <= queryR)
    {
        tree[node].sum -= 1LL * tree[node].max.count * value;
        tree[node].max.value -= value;
        if (l < r)
        {
            int mid = (l + r) / 2;
            updateMAXminus(l, mid, 2*node, queryL, queryR, value);
            updateMAXminus(mid + 1, r, 2*node + 1, queryL, queryR, value);
            tree[node] = combine(tree[2*node], tree[2*node + 1]);
        }

        return;
    }

    int mid = (l + r) / 2;
    if (queryL <= mid) updateMAXminus(l, mid, 2*node, queryL, queryR, value);
    if (mid + 1 <= queryR) updateMAXminus(mid + 1, r, 2*node + 1, queryL, queryR, value);
    tree[node] = combine(tree[2*node], tree[2*node + 1]);
}

// Summary of the range being cut, as returned by query(); cut() stores it
// here before calling updateMAX(), which reads original.max and original.max2.
Node original;

// Lowers maxima inside [queryL, queryR] from original.max.value down to
// original.max2.value. `node` is the tree index of the segment [l, r].
// Fully covered nodes are patched in place and then their children are
// visited recursively; inner nodes are rebuilt from their children afterwards.
void updateMAX(int l, int r, int node, int queryL, int queryR)
{
    // Nothing to lower when this subtree's maximum is below the range maximum.
    if (tree[node].max < original.max) return;

    // Partial overlap: descend into the overlapping children and recombine.
    if (queryL > l || r > queryR)
    {
        int mid = (l + r) / 2;
        if (queryL <= mid) updateMAX(l, mid, 2*node, queryL, queryR);
        if (mid + 1 <= queryR) updateMAX(mid + 1, r, 2*node + 1, queryL, queryR);
        tree[node] = combine(tree[2*node], tree[2*node + 1]);
        return;
    }

    // Fully covered, but (max, max2) is not the same value pair as the range's.
    if (tree[node].max != original.max || tree[node].max2 != original.max2)
    {
        // Same maximum, different second maximum: lower only the maximum.
        if (tree[node].max == original.max)
        {
            tree[node].sum -= 1LL * (tree[node].max.value - original.max2.value) * tree[node].max.count;
            tree[node].max.value = original.max2.value;
            if (l != r)
            {
                int mid = (l + r) / 2;
                updateMAX(l, mid, 2*node, queryL, queryR);
                updateMAX(mid + 1, r, 2*node + 1, queryL, queryR);
                tree[node] = combine(tree[2*node], tree[2*node + 1]);
            }
        }

        // A larger maximum than original.max is left as it is.
        return;
    }

    // Fully covered and both maxima match the range's: the lowered maximum
    // lands on the old second maximum, so the two counts are merged and the
    // second maximum is cleared.
    tree[node].sum -= 1LL * (tree[node].max.value - original.max2.value) * tree[node].max.count;
    tree[node].max.value = original.max2.value;
    if (tree[node].max2 == original.max2) tree[node].max.count += tree[node].max2.count;
    tree[node].max2 = {-1, 0};
    if (l != r)
    {
        // Inner node: the patch above is replaced by the recombined children.
        int mid = (l + r) / 2;
        updateMAX(l, mid, 2*node, queryL, queryR);
        updateMAX(mid + 1, r, 2*node + 1, queryL, queryR);
        tree[node] = combine(tree[2*node], tree[2*node + 1]);
    }
}

// Stores N and Q, copies h[1..N] into a[1..N] and builds the segment tree.
void initialise(int N, int Q, int h[])
{
    n = N;
    q = Q;

    // Both arrays are read/written starting at index 1.
    std::copy_n(h + 1, n, a + 1);

    build(1, n, 1);
}

// Applies a cut of total size k to the positions [l, r].
//
// Works in rounds. Each round queries the range for its largest value, the
// runner-up value and how many positions hold the largest value:
//   - if k is enough to bring every maximum down to the runner-up, that is
//     done through updateMAX() and k is reduced by the amount spent;
//   - otherwise all maxima are lowered evenly with updateMAXminus() and the
//     remainder is taken, one unit each, from the leftmost maxima (found
//     with findFirstPos()); this always uses up the rest of k.
// Returns early when the largest value in the range is 0.
void cut(int l, int r, int k)
{
    while (k >= 1)
    {
        original = query(1, n, 1, l, r);
        if (original.max.value == 0) return;
        if (original.max2.value != -1 && k >= 1LL * (original.max.value - original.max2.value) * original.max.count)
        {
            // Enough budget to level every maximum with the second maximum.
            // (The original re-queried the range here into an unused local.)
            k -= 1LL * (original.max.value - original.max2.value) * original.max.count;
            updateMAX(1, n, 1, l, r);
        } else
        {
            bool firstIf = false;
            if (original.max2.value == -1)
            {
                // Only one distinct value in the range: use 0 as the floor and
                // cap k at the amount needed to bring every position to 0.
                firstIf = true;
                original.max2.value = 0;
                if (k > 1LL * original.max.value * original.max.count)
                {
                    k = 1LL * original.max.value * original.max.count;
                }
            }

            if (firstIf || k <= 1LL * (original.max.value - original.max2.value - 1) * original.max.count)
            {
                // The maxima stay strictly above the floor: lower all of them
                // by k / count, then the first k % count of them by one more.
                searchedMax = original.max.value;
                updateMAXminus(1, n, 1, l, r, k / original.max.count);
                searchedMax = original.max.value - k / original.max.count;
                k %= original.max.count;
                if (k != 0)
                {
                    toRemove = k;
                    int firstPos = findFirstPos(1, n, 1, l, r);
                    updateMAXminus(1, n, 1, l, firstPos, 1);
                    k = 0;
                }
            } else
            {
                // Lower every maximum to one above the second maximum, then
                // let updateMAX() handle the positions up to firstPos.
                searchedMax = original.max.value;
                updateMAXminus(1, n, 1, l, r, (original.max.value - original.max2.value - 1));
                k %= original.max.count;
                if (k == 0) 
                {
                    k = original.max.count;
                }

                if (k != 0)
                {
                    toRemove = k;
                    searchedMax = original.max2.value + 1;
                    int firstPos = findFirstPos(1, n, 1, l, r);
                    updateMAX(1, n, 1, l, firstPos);
                }
                k = 0;
            }
        }
    }
}

// Point assignment: position i takes the value x.
void magic(int i, int x)
{
    const int position = i;
    const int newValue = x;
    updatePos(1, n, 1, position, newValue);
}

// Returns the sum of the values at positions [l, r].
llong inspect(int l, int r)
{
    const Node range = query(1, n, 1, l, r);
    return range.sum;
}

Compilation message

weirdtree.cpp: In function 'int findFirstPos(int, int, int, int, int)':
weirdtree.cpp:158:16: warning: suggest parentheses around comparison in operand of '&' [-Wparentheses]
  158 |     if (queryL <= l & r <= queryR)
      |         ~~~~~~~^~~~
weirdtree.cpp: In function 'void cut(int, int, int)':
weirdtree.cpp:277:18: warning: variable 'res' set but not used [-Wunused-but-set-variable]
  277 |             Node res = query(1, n, 1, l, r);
      |                  ^~~
# 결과 실행 시간 메모리 Grader output
1 Correct 31 ms 47308 KB Output is correct
2 Correct 30 ms 47316 KB Output is correct
# 결과 실행 시간 메모리 Grader output
1 Correct 31 ms 47308 KB Output is correct
2 Correct 30 ms 47316 KB Output is correct
3 Correct 149 ms 48320 KB Output is correct
4 Correct 219 ms 48284 KB Output is correct
# 결과 실행 시간 메모리 Grader output
1 Correct 40 ms 47308 KB Output is correct
2 Correct 50 ms 47324 KB Output is correct
# 결과 실행 시간 메모리 Grader output
1 Correct 40 ms 47308 KB Output is correct
2 Correct 50 ms 47324 KB Output is correct
3 Execution timed out 2075 ms 51828 KB Time limit exceeded
4 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Correct 320 ms 51152 KB Output is correct
2 Execution timed out 2083 ms 51160 KB Time limit exceeded
3 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Correct 31 ms 47308 KB Output is correct
2 Correct 30 ms 47316 KB Output is correct
3 Correct 149 ms 48320 KB Output is correct
4 Correct 219 ms 48284 KB Output is correct
5 Correct 40 ms 47308 KB Output is correct
6 Correct 50 ms 47324 KB Output is correct
7 Correct 146 ms 48264 KB Output is correct
8 Correct 208 ms 48368 KB Output is correct
9 Correct 146 ms 48304 KB Output is correct
10 Correct 224 ms 48352 KB Output is correct
11 Correct 153 ms 48332 KB Output is correct
12 Execution timed out 2088 ms 47948 KB Time limit exceeded
13 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Correct 31 ms 47308 KB Output is correct
2 Correct 30 ms 47316 KB Output is correct
3 Correct 149 ms 48320 KB Output is correct
4 Correct 219 ms 48284 KB Output is correct
5 Correct 40 ms 47308 KB Output is correct
6 Correct 50 ms 47324 KB Output is correct
7 Execution timed out 2075 ms 51828 KB Time limit exceeded
8 Halted 0 ms 0 KB -