답안 #307786

# 제출 시각 아이디 문제 언어 결과 실행 시간 메모리
307786 2020-09-29T10:08:41 Z VEGAnn 식물 비교 (IOI20_plants) C++17
컴파일 오류
0 ms 0 KB
#include "plants.h"
#include <bits/stdc++.h>
#define PB push_back
#define sz(x) ((int)x.size())
#define i2 array<int,2>
using namespace std;
// Capacity used to size every static array below.
const int N = 200100;
// Sentinel "infinity": returned for empty segment-tree ranges, stored in unset
// `mn` leaves, and added to a counter in `st` to retire it once it reaches zero.
const int oo = 2e9;
// Number of binary-lifting levels kept in `up`.
const int PW = 20;
// Positions currently held as candidates (inserted by init/prihodi, removed by uhodi).
set<int> locs;
// Scratch iterator into `locs`, shared by uhodi and prihodi.
set<int>::iterator iter;
// While `locs` has at least two members: one entry {circular gap from the
// preceding member, position} per member. init always takes the largest entry.
set<i2> mx;
// vc: positions whose initial r value is 0; glob: copy of the r array given to init.
vector<int> vc, glob;
// per[i]      : rank assigned to position i by init (n-1 is handed out first).
// psh[v]      : pending lazy addition for node v of `st`.
// n, k        : number of positions and the window length K, both set in init.
// nt[d][i]    : single jump from i (d = 0: window ending at i, d = 1: window starting at i).
// up[d][i][p] : position reached from i after 2^p jumps in direction d.
int per[N], psh[4 * N], n, nt[2][N], up[2][N][PW], k;
// st: lazy range-add / range-min tree over the r counters, nodes are {value, position}.
// mn: point-update / range-min tree over assigned ranks, nodes are {rank, position}.
i2 st[4 * N], mn[4 * N];

// Fills the counter tree `st` for the segment [l, r] rooted at node v.
// A leaf holds {glob[position], position}; an inner node holds the smaller of
// its two children (ties resolved by the array comparison, i.e. lower position).
void build(int v, int l, int r){
    if (l != r){
        const int mid = (l + r) >> 1;
        const int left_child = v + v;
        const int right_child = left_child + 1;

        build(left_child, l, mid);
        build(right_child, mid + 1, r);

        st[v] = min(st[left_child], st[right_child]);
    } else {
        st[v] = {glob[l], l};
    }
}

// Applies the pending lazy addition of node v to its stored minimum and hands
// the same amount down to both child slots (when those slots fit in the array).
void push(int v){
    const int pending = psh[v];

    if (pending == 0) return;

    psh[v] = 0;
    st[v][0] += pending;

    const int right_child = v + v + 1;

    // The bound check keeps leaves near the end of the array from writing past it.
    if (right_child < 4 * N){
        psh[right_child - 1] += pending;
        psh[right_child] += pending;
    }
}

// Adds vl to every counter in [l, r] of the lazy tree `st`.
// v covers [tl, tr]; the query range is clipped on the way down, so an empty
// range (l > r) means this subtree is untouched apart from the initial push.
void del(int v, int tl, int tr, int l, int r, int vl){
    // Always flush first so the parent can safely recompute from this node.
    push(v);

    if (l > r) return;

    if (tl != l || tr != r){
        const int mid = (tl + tr) >> 1;
        const int left_child = v + v;
        const int right_child = left_child + 1;

        del(left_child, tl, mid, l, min(r, mid), vl);
        del(right_child, mid + 1, tr, max(mid + 1, l), r, vl);

        st[v] = min(st[left_child], st[right_child]);
        return;
    }

    // Exact cover: record the addition lazily and apply it to this node at once.
    psh[v] = vl;
    push(v);
}

// Returns the minimum {counter, position} over [l, r] in the lazy tree `st`.
// v covers [tl, tr]. An empty range yields the sentinel {oo, 0}.
i2 get(int v, int tl, int tr, int l, int r){
    push(v);

    if (l > r){
        return {oo, 0};
    }

    if (tl == l && tr == r){
        return st[v];
    }

    const int mid = (tl + tr) >> 1;

    const i2 left_best = get(v + v, tl, mid, l, min(r, mid));
    const i2 right_best = get(v + v + 1, mid + 1, tr, max(mid + 1, l), r);

    return min(left_best, right_best);
}

// Number of forward steps from fi to se on the ring of n positions.
// Equal arguments give n (a full lap), not 0.
int dist(int fi, int se){
    const int gap = se - fi;

    return (se > fi) ? gap : gap + n;
}

// Removes position vl from the candidate ring `locs` and repairs the gap set `mx`.
// With two or fewer candidates the gap set is simply emptied; otherwise the two
// gaps touching vl are replaced by the single merged gap between its neighbours.
void uhodi(int vl){
    if (sz(locs) <= 2){
        locs.erase(vl);
        mx.clear();
        return;
    }

    iter = locs.find(vl);

    // Circular neighbours: wrap to the other end of the set at the borders.
    const int before = (iter == locs.begin()) ? *locs.rbegin() : *prev(iter);

    const auto after_it = next(iter);
    const int after = (after_it == locs.end()) ? *locs.begin() : *after_it;

    mx.erase({dist(before, vl), vl});
    mx.erase({dist(vl, after), after});

    locs.erase(iter);

    mx.insert({dist(before, after), after});
}

// Adds position vl to the candidate ring `locs` and repairs the gap set `mx`.
// The first candidate has no gaps; afterwards the gap that used to span vl is
// split into the gap ending at vl and the gap ending at its successor.
void prihodi(int vl){
    const bool was_empty = (sz(locs) == 0);

    locs.insert(vl);

    if (was_empty) return;

    iter = locs.find(vl);

    // Circular neighbours: wrap to the other end of the set at the borders.
    const int before = (iter == locs.begin()) ? *locs.rbegin() : *prev(iter);

    const auto after_it = next(iter);
    const int after = (after_it == locs.end()) ? *locs.begin() : *after_it;

    mx.erase({dist(before, after), after});

    mx.insert({dist(before, vl), vl});
    mx.insert({dist(vl, after), after});
}

// Resets the rank tree `mn` for the segment [l, r] rooted at node v:
// every node starts as {oo, left end of its segment}, i.e. "no rank assigned".
void BUILD(int v, int l, int r){
    mn[v] = {oo, l};

    if (l != r){
        const int mid = (l + r) >> 1;

        BUILD(v + v, l, mid);
        BUILD(v + v + 1, mid + 1, r);
    }
}

// Stores rank vl at position ps of the rank tree `mn` and refreshes the
// minima on the path back up to node v (which covers [l, r]).
void update_mn(int v, int l, int r, int ps, int vl){
    if (l == r){
        mn[v] = {vl, l};
        return;
    }

    const int mid = (l + r) >> 1;
    const int left_child = v + v;
    const int right_child = left_child + 1;

    if (ps > mid){
        update_mn(right_child, mid + 1, r, ps, vl);
    } else {
        update_mn(left_child, l, mid, ps, vl);
    }

    mn[v] = min(mn[left_child], mn[right_child]);
}

// Returns the minimum {rank, position} over [l, r] in the rank tree `mn`.
// v covers [tl, tr]. An empty range (l > r) yields the sentinel {oo, 0}, and a
// range with no assigned rank yields a pair whose first component is oo.
i2 get_mn(int v, int tl, int tr, int l, int r){
    if (l > r) return {oo, 0};

    // Logical AND (the original used bitwise '&', which happened to give the
    // same result but triggered -Wparentheses and defeated short-circuiting).
    if (tl == l && tr == r) return mn[v];

    int md = (tl + tr) >> 1;

    return min(get_mn(v + v, tl, md, l, min(r, md)),
               get_mn(v + v + 1, md + 1, tr, max(md + 1, l), r));
}

void init(int K, vector<int> r) {
    n = sz(r);
    k = K;

    glob = r;

    build(1, 0, n - 1);

    vc.clear();

    for (int i = 0; i < n; i++)
        if (r[i] == 0)
            vc.PB(i);

    for (int it = 0; it < sz(vc); it++){
        int cur = vc[it];
        int nxt = vc[(it + 1) % sz(vc)];

        locs.insert(cur);

        del(1, 0, n - 1, cur, cur, oo);

        if (sz(vc) == 1) continue;

        if (cur == vc.back()){
            mx.insert({nxt + n - cur, nxt});
        } else {
            mx.insert({nxt - cur, nxt});
        }
    }

    BUILD(1, 0, n - 1);

    for (int it = n - 1; it >= 0; it--){
        assert(sz(locs) > 0);

        int cand = -1;

        if (sz(locs) == 1){
            int vl = (*locs.begin());

            cand = vl;

            uhodi(vl);
        } else {
            cand = (*(--mx.end()))[1];

            uhodi(cand);
        }

        // delete with segment tree

        int rt = cand, lf = cand - k + 1;

        per[cand] = it;

        if (lf < 0){
            lf += n;

            del(1, 0, n - 1, 0, rt, -1);

            i2 cur = get(1, 0, n - 1, 0, rt);

            while (cur[0] == 0){
                prihodi(cur[1]);
                del(1, 0, n - 1, cur[1], cur[1], oo);
                cur = get(1, 0, n - 1, 0, rt);
            }

            del(1, 0, n - 1, lf, n - 1,  -1);

            cur = get(1, 0, n - 1, lf, n - 1);

            while (cur[0] == 0){
                prihodi(cur[1]);
                del(1, 0, n - 1, cur[1], cur[1], oo);
                cur = get(1, 0, n - 1, lf, n - 1);
            }

            cur = min(get_mn(1, 0, n - 1, 0, rt), get_mn(1, 0, n - 1, lf, n - 1));

            if (cur[0] == oo)
                nt[0][cand] = cand;
            else nt[0][cand] = cur[1];
        } else {
            del(1, 0, n - 1, lf, rt, -1);

            i2 cur = get(1, 0, n - 1, lf, rt);

            while (cur[0] == 0){
                prihodi(cur[1]);
                del(1, 0, n - 1, cur[1], cur[1], oo);
                cur = get(1, 0, n - 1, lf, rt);
            }

            cur = get_mn(1, 0, n - 1, lf, rt);

            if (cur[0] == oo)
                nt[0][cand] = cand;
            else nt[0][cand] = cur[1];
        }

        lf = cand; rt = cand + k - 1;

        if (rt >= n){
            rt -= n;

            i2 cur = min(get_mn(1, 0, n - 1, 0, rt), get_mn(1, 0, n - 1, lf, n - 1));

            if (cur[0] == oo)
                nt[1][cand] = cand;
            else nt[1][cand] = cur[1];
        } else {
            i2 cur = get_mn(1, 0, n - 1, lf, rt);

            if (cur[0] == oo)
                nt[1][cand] = cand;
            else nt[1][cand] = cur[1];
        }

        update_mn(1, 0, n - 1, cand, per[cand]);
    }

    for (int i = 0; i < n; i++){
        up[0][i][0] = nt[0][i];
        up[1][i][0] = nt[1][i];
    }

    for (int po = 1; po < PW; po++)
    for (int i = 0; i < n; i++){
        up[0][i][po] = up[0][up[0][i][po - 1]][po - 1];
        up[1][i][po] = up[1][up[1][i][po - 1]][po - 1];
    }

    return;
}

// Returns true when a chain of jumps starting at x reaches a position that lies
// within k of y and whose rank `per` is below per[y].
//
// Every jump target was ranked earlier in init than its source, so ranks never
// decrease along a chain; finding such a position therefore places x below y.
// Two walks are tried:
//   * direction 0 follows up[0] (targets in the k-window ending at the current
//     position) and measures how far the position lies after y on the ring;
//   * direction 1 follows up[1] (targets in the k-window starting at the
//     current position) and measures how far the position lies before y.
// Each walk takes the longest prefix of jumps that stays on the arc between x
// and y without entering y's k-neighbourhood, then tests the position reached
// and the one a single further jump leads to.
bool smaller(int x, int y){
    // left direction

    int loc = x;

    for (int po = PW - 1; po >= 0; po--){
        int nw = up[0][loc][po];

        // Ring distance from y forward to the landing position.
        int dst = nw - y;

        if (dst < 0) dst += n;

        // Do not jump into y's neighbourhood; the final single step does that.
        if (dst < k) continue;

        // Accept the jump only if it stays on the arc (y, x] walked backwards from x.
        if (x > y){
            if (nw > y && nw <= x)
                loc = nw;
        } else {
            if (nw <= x || nw > y)
                loc = nw;
        }
    }

    int dst = loc - y;

    if (dst < 0) dst += n;

    if (dst < k && per[loc] < per[y])
        return 1;

    loc = up[0][loc][0];

    dst = loc - y;

    if (dst < 0) dst += n;

    if (x > y){
        if ((loc > y && loc <= x) && dst < k && per[loc] < per[y])
            return 1;
    } else {
        if ((loc <= x || loc > y) && dst < k && per[loc] < per[y])
            return 1;
    }

    //right direction

    loc = x;

    for (int po = PW - 1; po >= 0; po--){
        int nw = up[1][loc][po];

        // Ring distance from the landing position forward to y. The original
        // used `loc - y` here (current position, opposite direction), which
        // does not mirror the left walk and disagrees with the check below.
        int dst = y - nw;

        if (dst < 0) dst += n;

        if (dst < k) continue;

        // Accept the jump only if it stays on the arc [x, y) walked forwards from x.
        if (x < y){
            if (nw < y && nw >= x)
                loc = nw;
        } else {
            if (nw >= x || nw < y)
                loc = nw;
        }
    }

    dst = y - loc;

    if (dst < 0) dst += n;

    if (dst < k && per[loc] < per[y])
        return 1;

    loc = up[1][loc][0];

    // Same forward distance as above (the original measured `loc - y`).
    dst = y - loc;

    if (dst < 0) dst += n;

    if (x < y){
        if ((loc < y && loc >= x) && dst < k && per[loc] < per[y])
            return 1;  // the original was missing this ';' (compile error)
    } else {
        if ((loc >= x || loc < y) && dst < k && per[loc] < per[y])
            return 1;
    }

    return 0;
}

// Grader entry point: -1 when x is shown to rank below y, +1 when y is shown
// to rank below x, and 0 when neither direction can be established.
int compare_plants(int x, int y) {
    if (smaller(x, y)){
        return -1;
    }

    // Only probe the reverse direction when the first one failed.
    return smaller(y, x) ? +1 : 0;
}

Compilation message

plants.cpp: In function 'std::array<int, 2> get_mn(int, int, int, int, int)':
plants.cpp:163:12: warning: suggest parentheses around comparison in operand of '&' [-Wparentheses]
  163 |     if (tl == l & tr == r) return mn[v];
      |         ~~~^~~~
plants.cpp: In function 'bool smaller(int, int)':
plants.cpp:389:21: error: expected ';' before '}' token
  389 |             return 1
      |                     ^
      |                     ;
  390 |     } else {
      |     ~