Submission #1074583

# | Time | Username | Problem | Language | Result | Execution time | Memory
1074583 | | shmax | 송신탑 (Radio Towers, IOI22_towers) | C++17 | 17 / 100 | 879 ms | 46256 KiB
#include "towers.h"
#include <bits/stdc++.h>

using namespace std;
// Bit i of x. NOTE(review): unused in the visible code, and the macro
// arguments are not parenthesized, so compound expressions may misbehave.
#define bit(x, i) ((x >> i) & 1)
// "Infinite" height sentinel. It is assigned to int locals in init();
// 2e9 still fits in a 32-bit int.
#define inf 2000'000'000LL
// Short alias for std::vector.
template<typename T>
using vec = vector<T>;


/// Static range-query table for an idempotent, associative operation
/// (e.g. min or max) over the range [first, last).
/// Building costs O(n log n). get() answers an inclusive [l, r] query in O(1)
/// by combining two overlapping power-of-two windows.
template<typename it>
struct SparseTable {

    using T = typename std::remove_reference<decltype(*std::declval<it>())>::type;
    std::vector<std::vector<T>> t;  // t[k][j] = op over the window [j, j + 2^k)
    std::function<T(T, T)> f;       // the combining operation
    std::vector<int> log;           // log[len] = floor(log2(len)) for len >= 1

    SparseTable() = default;

    /// Builds the table over [first, last) using `op`.
    /// An empty range yields an empty table. get() must not be called on it.
    SparseTable(it first, it last, std::function<T(T, T)> op) : f(std::move(op)) {
        const int n = static_cast<int>(std::distance(first, last));
        log.assign(n + 1, 0);
        for (int i = 2; i <= n; i++)
            log[i] = log[i / 2] + 1;
        // Guard: the level count is log[n] + 1, which has no meaning for n == 0
        // (the previous __builtin_clz(0) formulation was undefined behaviour).
        if (n == 0)
            return;

        t.assign(log[n] + 1, std::vector<T>(n));
        t[0].assign(first, last);
        for (std::size_t k = 1; k < t.size(); k++) {
            const int half = 1 << (k - 1);
            // Only windows that lie fully inside [0, n) are filled.
            for (int j = 0; j + (1 << k) <= n; j++)
                t[k][j] = f(t[k - 1][j], t[k - 1][j + half]);
        }
    }

    /// Combines the elements at positions [l, r] (inclusive; 0 <= l <= r < n).
    T get(int l, int r) const {
        const int k = log[r - l + 1];
        return f(t[k][l], t[k][r - (1 << k) + 1]);
    }
};


namespace STREE1 {
    /// Node of a persistent (path-copying) sum segment tree.
    /// `l` / `r` index the children in `nodes` (-1 for a leaf).
    /// `sum` is the sum of the leaf values below this node.
    struct node {
        int l, r;
        int sum;
    };

    // Nodes of every version. Links are indices, so growing the vector never
    // invalidates a link.
    std::vector<node> nodes;

    /// Recomputes nodes[x].sum from its children; a missing child counts as 0.
    void fetch(int x) {
        const int left_sum = nodes[x].l == -1 ? 0 : nodes[nodes[x].l].sum;
        const int right_sum = nodes[x].r == -1 ? 0 : nodes[nodes[x].r].sum;
        nodes[x].sum = left_sum + right_sum;
    }

    /// Appends a childless node with sum 0 and returns its index.
    int new_node() {
        nodes.push_back(node{-1, -1, 0});
        return static_cast<int>(nodes.size()) - 1;
    }

    // Root index of each version; callers append to it.
    std::vector<int> roots;

    /// Builds an all-zero tree over positions [tl, tr] and returns its root index.
    /// Each parent is allocated before its children.
    int build(int tl, int tr) {
        const int id = new_node();  // a fresh node already has sum 0
        if (tl != tr) {
            const int mid = (tl + tr) / 2;
            const int left = build(tl, mid);
            const int right = build(mid + 1, tr);
            nodes[id].l = left;
            nodes[id].r = right;
            fetch(id);
        }
        return id;
    }

    /// Sum of leaf values over [l, r] in the tree rooted at v.
    /// v spans [tl, tr], and tl <= l <= r <= tr.
    int get(int v, int tl, int tr, int l, int r) {
        if (l == tl && r == tr)
            return nodes[v].sum;
        const int mid = (tl + tr) / 2;
        int total = 0;
        if (l <= mid)
            total += get(nodes[v].l, tl, mid, l, std::min(r, mid));
        if (r > mid)
            total += get(nodes[v].r, mid + 1, tr, std::max(l, mid + 1), r);
        return total;
    }

    /// Returns the root of a new version. It equals the tree rooted at v, except
    /// that leaf `pos` holds x. Only the nodes on the root-to-leaf path are
    /// copied, and each copy is allocated before its child on that path.
    int update(int v, int tl, int tr, int pos, int x) {
        if (tl == tr) {
            const int leaf = new_node();
            nodes[leaf].sum = x;
            return leaf;
        }
        const int mid = (tl + tr) / 2;
        const int copy = new_node();
        int left = nodes[v].l;
        int right = nodes[v].r;
        if (pos <= mid)
            left = update(left, tl, mid, pos, x);
        else
            right = update(right, mid + 1, tr, pos, x);
        nodes[copy].l = left;
        nodes[copy].r = right;
        fetch(copy);
        return copy;
    }

    /// Shared search for find_left / find_right. Returns the smallest
    /// (`leftmost` == true) or largest position in [l, r] reachable through
    /// subtrees with non-zero sum, or -1 if there is none.
    /// Pruning on sum == 0 assumes leaf values are non-negative.
    int find_extreme(int v, int tl, int tr, int l, int r, bool leftmost) {
        if (nodes[v].sum == 0)
            return -1;
        if (tl == tr)
            return tl;
        const int mid = (tl + tr) / 2;
        if (r <= mid)
            return find_extreme(nodes[v].l, tl, mid, l, r, leftmost);
        if (l > mid)
            return find_extreme(nodes[v].r, mid + 1, tr, l, r, leftmost);
        if (leftmost) {
            const int hit = find_extreme(nodes[v].l, tl, mid, l, mid, true);
            return hit != -1 ? hit : find_extreme(nodes[v].r, mid + 1, tr, mid + 1, r, true);
        }
        const int hit = find_extreme(nodes[v].r, mid + 1, tr, mid + 1, r, false);
        return hit != -1 ? hit : find_extreme(nodes[v].l, tl, mid, l, mid, false);
    }

    /// Leftmost position in [l, r] with a non-zero leaf, or -1.
    int find_left(int v, int tl, int tr, int l, int r) {
        return find_extreme(v, tl, tr, l, r, true);
    }

    /// Rightmost position in [l, r] with a non-zero leaf, or -1.
    int find_right(int v, int tl, int tr, int l, int r) {
        return find_extreme(v, tl, tr, l, r, false);
    }
}

namespace STREE2 {
    // Former fixed capacity. Storage now grows in build() to fit the input, so
    // this is no longer a hard limit; it is kept for source compatibility.
    const int maxN = 1e5 + 5;

    /// Aggregate over a contiguous segment of the array:
    ///   min / max - smallest / largest value in the segment,
    ///   min_diff  - min of a[j] - a[i] over i <= j in the segment (always <= 0),
    ///   max_diff  - max of a[j] - a[i] over i <= j in the segment (always >= 0).
    struct node {
        int min, max;
        int min_diff;
        int max_diff;
    };

    /// Merges the aggregate `a` of a left segment with the aggregate `b` of the
    /// adjacent right segment. Cross pairs take i from `a` and j from `b`.
    /// NOTE(review): differences are computed in int. This is fine for values
    /// in [0, 1e9], but could overflow for arbitrary ints.
    node combine(node a, node b) {
        node res;
        res.min = std::min(a.min, b.min);
        res.max = std::max(a.max, b.max);
        res.min_diff = std::min({a.min_diff, b.min_diff, b.min - a.max});
        res.max_diff = std::max({a.max_diff, b.max_diff, b.max - a.min});
        return res;
    }

    // 1-rooted heap layout: the children of v are 2v and 2v + 1.
    std::vector<node> stree;

    /// Builds the subtree rooted at v over a[tl..tr]. Call as build(1, 0, n - 1, a).
    void build(int v, int tl, int tr, const std::vector<int> &a) {
        // A 1-rooted tree over m leaves uses indices < 4m. Growing here (a no-op
        // after the first call) removes the old fixed 4 * maxN capacity, which
        // overflowed silently for larger inputs.
        if (stree.size() < 4 * a.size())
            stree.resize(4 * a.size());
        if (tl == tr) {
            stree[v] = {a[tl], a[tl], 0, 0};
            return;
        }
        const int tm = (tl + tr) / 2;
        build(2 * v, tl, tm, a);
        build(2 * v + 1, tm + 1, tr, a);
        stree[v] = combine(stree[2 * v], stree[2 * v + 1]);
    }

    /// Aggregate of positions [l, r]; v spans [tl, tr] with tl <= l <= r <= tr.
    node get(int v, int tl, int tr, int l, int r) {
        if (tl == l && tr == r)
            return stree[v];
        const int tm = (tl + tr) / 2;
        if (r <= tm)
            return get(2 * v, tl, tm, l, r);
        if (l > tm)
            return get(2 * v + 1, tm + 1, tr, l, r);
        return combine(get(2 * v, tl, tm, l, tm), get(2 * v + 1, tm + 1, tr, tm + 1, r));
    }
}

// Tower heights and their count, copied in init().
vec<int> h;
int n;
// (delta_i, i) for every tower i. init() sorts these by descending delta.
vec<pair<int, int>> deltas;
// Range-minimum / range-maximum tables over h, built in init().
SparseTable<vec<int>::iterator> stmin;
SparseTable<vec<int>::iterator> stmax;

void init(int N, std::vector<int> H) {
    n = N;
    h = H;
    stmin = SparseTable(h.begin(), h.end(), [](int a, int b) { return min(a, b); });
    stmax = SparseTable(h.begin(), h.end(), [](int a, int b) { return max(a, b); });
    for (int i = 0; i < n; i++) {
        int h1 = inf;
        if (i != 0) {
            if (stmin.get(0, i - 1) < h[i]) {
                int tl = 0;
                int tr = i - 1;
                while (tl != tr) {
                    int tm = (tl + tr + 1) / 2;
                    if (stmin.get(tm, i - 1) < h[i]) {
                        tl = tm;
                    } else {
                        tr = tm - 1;
                    }
                }
                h1 = stmax.get(tl, i - 1);
            }
        }
        int h2 = inf;
        if (i != n - 1) {
            if (stmin.get(i + 1, n - 1) < h[i]) {
                int tl = i + 1;
                int tr = n - 1;
                while (tl != tr) {
                    int tm = (tl + tr) / 2;
                    if (stmin.get(i + 1, tm) < h[i]) {
                        tr = tm;
                    } else {
                        tl = tm + 1;
                    }
                }
                h2 = stmax.get(i + 1, tr);
            }
        }
        int max_delta = min(h1, h2) - h[i];
        deltas.push_back({max_delta, i});
    }

    sort(deltas.rbegin(), deltas.rend());
    int t = STREE1::build(0, n - 1);
    STREE1::roots.push_back(t);
    for (auto [d, i]: deltas) {
        t = STREE1::update(STREE1::roots.back(), 0, n - 1, i, 1);
        STREE1::roots.push_back(t);
    }
    STREE2::build(1, 0, n - 1, h);
}

int max_towers(int L, int R, int D) {
    int tl = 0;
    int tr = n - 1;
    while (tl != tr) {
        int tm = (tl + tr + 1) / 2;
        if (deltas[tm].first >= D) {
            tl = tm;
        } else {
            tr = tm - 1;
        }
    }
    if (deltas[0].first < D)
        tl = -1;
    int vr = STREE1::roots[tl + 1];
    int ans = STREE1::get(vr, 0, n - 1, L, R);
    if (ans == 0)
        return 1;
    int pl = STREE1::find_left(vr, 0, n - 1, L, R);
    int pr = STREE1::find_right(vr, 0, n - 1, L, R);
    if (pl > L + 1) {
        tl = L;
        tr = pl - 1;
        while (tl != tr) {
            int tm = (tl + tr + 1) / 2;
            if (stmax.get(tm, pl - 1) + D >= h[pl]) {
                tl = tm;
            } else {
                tr = tm - 1;
            }
        }
        if (STREE2::get(1, 0, n - 1, L, tl).max_diff >= D)
            ans++;
    }
    if (pr < R - 1) {
        tl = pr + 1;
        tr = R;
        while (tl != tr) {
            int tm = (tl + tr) / 2;
            if (stmax.get(pr + 1, tm) + D >= h[pr]) {
                tr = tm;
            } else {
                tl = tm + 1;
            }
        }
        if (STREE2::get(1, 0, n - 1, tl, R).min_diff <= -D)
            ans++;
    }
    return ans;
}

Compilation message (stderr)

towers.cpp: In instantiation of 'SparseTable<it>::SparseTable(it, it, std::function<typename std::remove_reference<decltype (* declval<it>())>::type(typename std::remove_reference<decltype (* declval<it>())>::type, typename std::remove_reference<decltype (* declval<it>())>::type)>) [with it = __gnu_cxx::__normal_iterator<int*, std::vector<int> >; typename std::remove_reference<decltype (* declval<it>())>::type = int]':
towers.cpp:213:83:   required from here
towers.cpp:29:27: warning: comparison of integer expressions of different signedness: 'int' and 'std::vector<std::vector<int>, std::allocator<std::vector<int> > >::size_type' {aka 'long unsigned int'} [-Wsign-compare]
   29 |         for (int i = 1; i < t.size(); i++)
      |                         ~~^~~~~~~~~~
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...