제출 #649150

# | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리
649150 | boris_mihov | Weirdtree (RMI21_weirdtree) | C++17
13 / 100
214 ms | 58280 KiB
#include "weirdtree.h"
#include <algorithm>
#include <iostream>
#include <numeric>
#include <cassert>
#include <vector>

// 64-bit signed integer alias used for values, counts and sums.
typedef long long llong;
// Capacity of the element array: 300000 plus 10 spare slots.
const int MAXN = 300000 + 10;
// Large sentinel constant; not referenced by the code visible in this file.
const int INF  = 2e9;

// n and q are stored by initialise() from its N and Q arguments.
int n, q;
// Element values copied in by initialise(); indices 1..n are filled.
int a[MAXN];
struct Number
{
    llong value, count;
    Number()
    {
        value = -1;
        count = 0;
    }

    Number(llong _value, llong _count)
    {
        value = _value;
        count = _count;
    }

    inline friend bool operator < (Number a, Number b)
    {
        return a.value < b.value;
    }

    inline friend bool operator == (Number a, Number b)
    {
        return a.value == b.value;
    }

    inline friend bool operator != (Number a, Number b)
    {
        return !(a == b);
    }

    inline friend bool operator > (Number a, Number b)
    {
        return a.value > b.value;
    }

    inline friend void operator += (Number &a, Number b)
    {
        if (a.value == b.value) a.count += b.count;
        if (a.value < b.value) a = b;
    }
};

struct Node
{
    Number max, max2;  
    llong sum;
    Node()
    {
        max = max2 = {-1, 0};
        sum = 0;
    }

} tree[4*MAXN];

// Merges the summaries of two adjacent intervals into one node.
// Returns a node whose sum is the total of both sides, whose max is the
// larger of the two maxima (counts pooled on a tie), and whose max2 is the
// largest value strictly below that max (again with pooled counts).
Node combine(Node left, Node right)
{
    Node merged;
    merged.sum = left.sum + right.sum;

    merged.max = left.max;
    merged.max += right.max;

    // Each side contributes one runner-up candidate: its own max when that
    // lost to the overall max, otherwise its own second max.
    merged.max2 = (left.max == merged.max) ? left.max2 : left.max;
    merged.max2 += ((right.max == merged.max) ? right.max2 : right.max);

    return merged;
}

void build(int l, int r, int node)
{
    if (l == r)
    {
        tree[node].max2 = {-1, 0};
        tree[node].max = {a[l], 1};
        tree[node].sum = a[l];
        return;
    }

    int mid = (l + r) / 2;
    build(l, mid, 2*node);
    build(mid + 1, r, 2*node + 1);
    tree[node] = combine(tree[2*node], tree[2*node + 1]);
}

// Returns the combined summary (max, max2, sum) of positions
// queryL..queryR, where tree[node] covers l..r.
Node query(int l, int r, int node, int queryL, int queryR)
{
    const bool covered = (queryL <= l && r <= queryR);
    if (covered) return tree[node];

    const int mid = (l + r) / 2;
    Node acc;   // starts as the default (empty) node

    if (queryL <= mid)
    {
        Node part = query(l, mid, 2*node, queryL, queryR);
        acc = combine(acc, part);
    }

    if (queryR > mid)
    {
        Node part = query(mid + 1, r, 2*node + 1, queryL, queryR);
        acc = combine(acc, part);
    }

    return acc;
}

void updatePos(int l, int r, int node, int queryPos, int queryVal)
{
    if (l == r)
    {
        tree[node].max2 = {-1, 0};
        tree[node].max = {queryVal, 1};
        tree[node].sum = queryVal;
        return;
    }

    int mid = (l + r) / 2;
    if (queryPos <= mid) updatePos(l, mid, 2*node, queryPos, queryVal);
    else updatePos(mid + 1, r, 2*node + 1, queryPos, queryVal);
    tree[node] = combine(tree[2*node], tree[2*node + 1]);
}

// Implicit arguments of findFirstPos():
//   toRemove    - 1-based rank (left to right) of the wanted occurrence; it is
//                 decremented for every matching element that is skipped.
//   searchedMax - the value being searched for; also read by updateMAXminus().
int toRemove;
int searchedMax;

// Finds the toRemove-th element (counting from the left) equal to searchedMax
// among positions queryL..queryR, where tree[node] covers l..r.
//
// Returns the position of that element, or -1 when the visited part of the
// range holds fewer than toRemove matches; in that case toRemove has been
// reduced by the number of matches that were skipped.
int findFirstPos(int l, int r, int node, int queryL, int queryR)
{
    if (l == r)
    {
        if (searchedMax != tree[node].max.value) return -1;
        if (toRemove != 1)
        {
            // A match, but not the wanted one yet: consume it and move on.
            toRemove--;
            return -1;
        }

        return l;
    }

    // Logical && (the original used bitwise &, which gave the same result but
    // triggered a -Wparentheses warning).
    if (queryL <= l && r <= queryR)
    {
        // The whole node is inside the range, so its max/count summary tells
        // us whether the wanted occurrence lives here at all.
        if (searchedMax != tree[node].max.value || tree[node].max.count < toRemove)
        {
            if (searchedMax == tree[node].max.value) toRemove -= tree[node].max.count;
            return -1;
        }

        int mid = (l + r) / 2;
        if (tree[node].max.value == tree[2*node].max.value && toRemove <= tree[2*node].max.count)
        {
            return findFirstPos(l, mid, 2*node, queryL, queryR);
        }

        // The wanted occurrence is in the right half; skip the left matches.
        if (searchedMax == tree[2*node].max.value) toRemove -= tree[2*node].max.count;
        return findFirstPos(mid + 1, r, 2*node + 1, queryL, queryR);
    }

    int mid = (l + r) / 2;
    if (queryL <= mid)
    {
        int res = findFirstPos(l, mid, 2*node, queryL, queryR);
        if (res != -1) return res;
    }

    // Only descend right when the right child overlaps the query. Without
    // this guard, leaves beyond queryR were inspected and an out-of-range
    // position could be returned instead of -1.
    if (queryR <= mid) return -1;
    return findFirstPos(mid + 1, r, 2*node + 1, queryL, queryR);
}

void updateMAXminus(int l, int r, int node, int queryL, int queryR, int value)
{
    if (tree[node].max.value < searchedMax) return;
    if (queryL <= l && r <= queryR)
    {
        tree[node].sum -= 1LL * tree[node].max.count * value;
        tree[node].max.value -= value;
        if (l < r)
        {
            int mid = (l + r) / 2;
            updateMAXminus(l, mid, 2*node, queryL, queryR, value);
            updateMAXminus(mid + 1, r, 2*node + 1, queryL, queryR, value);
            tree[node] = combine(tree[2*node], tree[2*node + 1]);
        }

        return;
    }

    int mid = (l + r) / 2;
    if (queryL <= mid) updateMAXminus(l, mid, 2*node, queryL, queryR, value);
    if (mid + 1 <= queryR) updateMAXminus(mid + 1, r, 2*node + 1, queryL, queryR, value);
    tree[node] = combine(tree[2*node], tree[2*node + 1]);
}

// Range summary that updateMAX() compares nodes against; cut() fills it with
// the result of query() over the range being processed.
Node original;

// Walks the nodes covering queryL..queryR (tree[node] covers l..r) and, in
// nodes whose maximum equals original.max, lowers that maximum to
// original.max2.value, adjusting sums along the way.
// NOTE(review): only called from the commented-out draft inside cut(), so the
// intended contract cannot be confirmed from live code.
void updateMAX(int l, int r, int node, int queryL, int queryR)
{
    // Node not fully inside the range: recurse into the overlapping children
    // and rebuild this node from them.
    if (queryL > l || r > queryR)
    {
        int mid = (l + r) / 2;
        if (queryL <= mid) updateMAX(l, mid, 2*node, queryL, queryR);
        if (mid + 1 <= queryR) updateMAX(mid + 1, r, 2*node + 1, queryL, queryR);
        tree[node] = combine(tree[2*node], tree[2*node + 1]);
        return;
    }

    // Fully covered node whose (max, max2) values do not both match the
    // range-wide pair stored in `original` (Number comparisons use value only).
    if (tree[node].max != original.max || tree[node].max2 != original.max2)
    {
        // Same maximum but a different second maximum: lower the maximum to
        // original.max2.value and let the children recompute max2.
        if (tree[node].max == original.max)
        {
            tree[node].sum -= 1LL * (tree[node].max.value - original.max2.value) * tree[node].max.count;
            tree[node].max.value = original.max2.value;
            if (l != r)
            {
                int mid = (l + r) / 2;
                updateMAX(l, mid, 2*node, queryL, queryR);
                updateMAX(mid + 1, r, 2*node + 1, queryL, queryR);
                tree[node] = combine(tree[2*node], tree[2*node + 1]);
            }
        }

        // A node whose maximum differs from original.max is left untouched.
        return;
    }

    // Both max and max2 match `original`: the maximum drops onto the second
    // maximum, so their counts are pooled and max2 is cleared.
    tree[node].sum -= 1LL * (tree[node].max.value - original.max2.value) * tree[node].max.count;
    tree[node].max.value = original.max2.value;
    // This branch is only reached when max2 matches original.max2 (see the
    // test above), so the condition holds here.
    if (tree[node].max2 == original.max2) tree[node].max.count += tree[node].max2.count;
    tree[node].max2 = {-1, 0};
    if (l != r)
    {
        int mid = (l + r) / 2;
        updateMAX(l, mid, 2*node, queryL, queryR);
        updateMAX(mid + 1, r, 2*node + 1, queryL, queryR);
        tree[node] = combine(tree[2*node], tree[2*node + 1]);
    }
}

// Stores the problem size, copies h[1..N] into the global array a[] and
// builds the segment tree over positions 1..N.
void initialise(int N, int Q, int h[])
{
    q = Q;
    n = N;

    int idx = 1;
    while (idx <= n)
    {
        a[idx] = h[idx];
        ++idx;
    }

    build(1, n, 1);
}

// Applies up to k single-unit cuts to positions l..r. Each step lowers the
// leftmost element holding the current maximum of the range by one; the loop
// stops early once the range maximum is no longer positive.
//
// The original body ignored k and always performed exactly one step. This
// version honours k by repeating that step, which is correct but costs
// O(k log n); large k needs a batched approach that lowers all maxima at
// once instead of one unit at a time.
void cut(int l, int r, int k)
{
    for (int step = 0 ; step < k ; ++step)
    {
        original = query(1, n, 1, l, r);
        if (original.max.value <= 0) return;   // nothing left to cut

        toRemove = 1;
        searchedMax = original.max.value;
        int firstPos = findFirstPos(1, n, 1, l, r);

        // Defensive: never feed the -1 sentinel into the update as a position.
        if (firstPos == -1) return;

        updateMAXminus(1, n, 1, firstPos, firstPos, 1);
    }
}

// Sets the element at position i to x by a point update on the whole tree.
void magic(int i, int x)
{
    const int rootNode = 1;
    updatePos(1, n, rootNode, i, x);
}

// Returns the sum of the elements at positions l..r.
llong inspect(int l, int r)
{
    const Node summary = query(1, n, 1, l, r);
    return summary.sum;
}

컴파일 시 표준 에러 (stderr) 메시지

weirdtree.cpp: In function 'int findFirstPos(int, int, int, int, int)':
weirdtree.cpp:157:16: warning: suggest parentheses around comparison in operand of '&' [-Wparentheses]
  157 |     if (queryL <= l & r <= queryR)
      |         ~~~~~~~^~~~
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...