답안 #687289

# 제출 시각 아이디 문제 언어 결과 실행 시간 메모리
687289 2023-01-26T09:12:36 Z 79brue 송신탑 (IOI22_towers) C++17
0 / 100
1333 ms 308756 KB
#include "towers.h"
#include <bits/stdc++.h>

using namespace std;

typedef long long ll;

int n;
int arr[100002];

void calculateIntervals();

void init(int N, vector<int> H){
    n = N;
    for(int i=0; i<n; i++) arr[i] = H[i];

    calculateIntervals();
}

/** PST 구조체
 *  1. 구간 합 쿼리 처리 기능
 *  2. 어떤 pair (x, y)가 들어올 때, 지금까지 들어온 pair의 max 처리 기능
 */
struct PST{
    struct Node{
        Node *lc, *rc;
        int sum;
        pair<int, int> lPair, rPair;

        Node(Node* node){
            lc = node->lc, rc = node->rc;
            sum = node->sum;
            lPair = node->lPair, rPair = node->rPair;
        }

        Node(int l, int r){
            lc = rc = nullptr;
            sum = 0;
            lPair = rPair = make_pair(-1, -1);

            if(l==r) return;
            int m = (l+r)>>1;
            lc = new Node(l, m);
            rc = new Node(m+1, r);
        }

        Node* update(int l, int r, int x, int v){
            if(l==r){
                Node *node = new Node(this);
                node->sum += v;
                return node;
            }
            int m = (l+r)>>1;
            Node *node = new Node(this);
            if(x<=m){
                node->lc = lc->update(l, m, x, v);
                node->sum += v;
                return node;
            }
            else{
                node->rc = rc->update(m+1, r, x, v);
                node->sum += v;
                return node;
            }
        }

        Node* updateL(int l, int r, int s, int e, pair<int, int> p){
            if(r<s || e<l) return this;
            if(s<=l && r<=e){
                Node *node = new Node(this);
                node->lPair = p;
                return node;
            }
            int m = (l+r)>>1;
            Node *node = new Node(this);
            node->lc = lc->updateL(l, m, s, e, p);
            node->rc = rc->updateL(m+1, r, s, e, p);
            return node;
        }

        Node* updateR(int l, int r, int s, int e, pair<int, int> p){
            if(r<s || e<l) return this;
            if(s<=l && r<=e){
                Node *node = new Node(this);
                node->rPair = p;
                return node;
            }
            int m = (l+r)>>1;
            Node *node = new Node(this);
            node->lc = lc->updateR(l, m, s, e, p);
            node->rc = rc->updateR(m+1, r, s, e, p);
            return node;
        }

        int query(int l, int r, int s, int e){
            if(r<s || e<l) return 0;
            if(s<=l && r<=e) return sum;
            int m = (l+r)>>1;
            return lc->query(l, m, s, e) + rc->query(m+1, r, s, e);
        }

        pair<int, int> queryL(int l, int r, int x){
            if(l==r) return lPair;
            int m = (l+r)>>1;
            if(x<=m) return max(lPair, lc->queryL(l, m, x));
            else return max(lPair, rc->queryL(m+1, r, x));
        }

        pair<int, int> queryR(int l, int r, int x){
            if(l==r) return rPair;
            int m = (l+r)>>1;
            if(x<=m) return max(rPair, lc->queryR(l, m, x));
            else return max(rPair, rc->queryR(m+1, r, x));
        }
    } *root;
    int n, curD;
    Node *history[200002];

    PST(){}

    void init(int N){
        n = N;
        root = new Node(0, n-1);
        curD = 0;
        memset(history, 0, sizeof(history));
    }

    void nextD(){
        history[curD++] = root;
    }

    void update(int x, int v){
        root = root->update(0, n-1, x, v);
    }

    void updateL(int l, int r, int idx){
        root = root->updateL(0, n-1, l, r, make_pair(curD, idx));
    }

    void updateR(int l, int r, int idx){
        root = root->updateR(0, n-1, l, r, make_pair(curD, idx));
    }

    vector<int> query(int l, int r, int d){
        /// 정수 3개의 배열을 리턴
        vector<int> ret (3);
        ret[0] = history[d]->query(0, n-1, l, r);
        ret[1] = history[d]->queryL(0, n-1, l).second;
        ret[2] = history[d]->queryR(0, n-1, r).second;

        return ret;
    }

    int queryL(int x, int d){
        return history[d]->queryL(0, n-1, x).second;
    }

    int queryR(int x, int d){
        return history[d]->queryR(0, n-1, x).second;
    }
} tree;

struct Interval{
    int s, e, diff, idx;
    Interval(int s, int e, int diff, int idx): s(s), e(e), diff(diff), idx(idx){}
};

struct xOrder {
    bool operator() (const Interval &l, const Interval &r)const{
        return l.s < r.s;
    }
};

struct diffOrder {
    bool operator() (const Interval &l, const Interval &r)const{
        if(abs(l.diff) != abs(r.diff)) return abs(l.diff) < abs(r.diff);
        return l.s < r.s;
    }
};

set<Interval, xOrder> xSet;
set<Interval, diffOrder> diffSet;

int sgn(int x){
    return x > 0 ? 1 : x == 0 ? 0 : -1;
}

vector<int> diffScale;
vector<Interval> intervals;
int leftInt[200002], rightInt[200002];

void makeInitialIntervals(){ /// 초기의 구간들을 계산한다
    tree.init(n);

    int prv = 0;
    while(prv < n-1){
        int j = prv;
        while(j+1<n && sgn(arr[j] - arr[prv]) * sgn(arr[j+1] - arr[j]) >= 0) j++;
        if(arr[j] == arr[prv]) break;

        /// 현재 찾은 구간을 셋에 집어넣는다
        int idx = (int)intervals.size();
        Interval tmp = Interval(prv, j, arr[j] - arr[prv], idx);
        intervals.push_back(tmp);
        xSet.insert(tmp);
        diffSet.insert(tmp);

        /// 현재 찾은 구간을 트리에 업데이트한다
        tree.update(prv, 1);
        tree.updateL(prv+1, j, idx);
        tree.updateR(prv, j-1, idx);

        prv = j;
    }
    diffScale.push_back(0);
    leftInt[tree.curD] = (xSet.empty() ? -1 : xSet.begin()->idx);
    rightInt[tree.curD] = (xSet.empty() ? -1 : xSet.rbegin()->idx);
    tree.nextD();
}

void calculateChanges(){ /// D가 커짐에 따라 생기는 변화들을 구한다
    while((int)diffSet.size()){
        int D = abs(diffSet.begin()->diff); /// 현재의 차이값
        while(abs((int)diffSet.begin()->diff) == D){ /// 해당하는 차이의 모든 구간을 처리한다.
            /// 구간을 뽑는다
            Interval p = *diffSet.begin();
            diffSet.erase(diffSet.begin());

            /// Case 1. 직전에 구간이 존재하지 않음
            if(xSet.begin()->s == p.s){ /// 이 구간은 그냥 삭제한다.
                xSet.erase(xSet.begin());
                tree.updateL(0, p.e, -1);
            }
            /// Case 2. 직후에 구간이 존재하지 않음
            else if(xSet.rbegin()->s == p.s){ /// 이 구간 역시 그냥 삭제한다.
                xSet.erase(prev(xSet.end()));
                tree.updateR(p.s, n-1, -1);
            }
            /// Case 3. 양쪽 모두 잘 존재함
            else{ /// 새로운 구간을 만들어 준다.
                auto it = xSet.find(p);
                int ns = prev(it)->s, ne = next(it)->e;
                int idx = (int)intervals.size();
                Interval np = Interval(ns, ne, arr[ne] - arr[ns], idx);
                intervals.push_back(np);

                /// 셋에 반영한다.
                diffSet.erase(*prev(it)), diffSet.erase(*next(it));
                xSet.erase(prev(it)), xSet.erase(next(it));
                xSet.erase(it);
                xSet.insert(np), diffSet.insert(np);

                /// 트리에 반영한다
                tree.update(p.s, -1);
                tree.update(p.e, -1);
                tree.updateL(np.s+1, np.e, idx);
                tree.updateR(np.s, np.e-1, idx);
            }
        }

        diffScale.push_back(D+1);
        leftInt[tree.curD] = (xSet.empty() ? -1 : xSet.begin()->idx);
        rightInt[tree.curD] = (xSet.empty() ? -1 : xSet.rbegin()->idx);
        tree.nextD();
    }
    leftInt[tree.curD] = rightInt[tree.curD] = -1;
}

void calculateIntervals(){
    makeInitialIntervals();
    calculateChanges();
}

int max_towers(int L, int R, int D){
    D = upper_bound(diffScale.begin(), diffScale.end(), D) - diffScale.begin() - 1;
    if(leftInt[D] == -1) return 1; /// 현 시기에 구간이 하나도 없는 경우

    vector<int> response = tree.query(L, R, D);
    int sum = response[0], ql = response[1], qr = response[2];

    /// 구간 시작점이 하나라도 있는 경우
    if(sum){
        /// 1. 완전히 속하는 구간의 개수를 센다.
        /// 왼쪽 끝 부분은 예외가 생기지 않는다.
        /// 오른쪽 끝 부분은 맨 오른쪽 점이 어느 구간에 포함되지 않았을 때를 제외하고는 1을 제외해야 한다.
        int insideIntervalCount = sum;
        if(qr != -1) insideIntervalCount--; /// 시작점이

        /// 2. 왼쪽 끝을 본다.
        /// 위에서 본 구간 중 맨 왼쪽 구간을 찾는다.
        int lint = (ql == -1) ? leftInt[D] : tree.queryL(intervals[ql].e+1, D);
        assert(lint != -1);
        int lDir = sgn(intervals[lint].diff);
        bool lMore = abs(arr[L] - arr[intervals[lint].s]) >= D;

        /// 3. 오른쪽 끝을 본다
        /// 위에서 본 구간 중 맨 오른쪽 구간을 찾는다.
        int rint = (qr == -1) ? rightInt[D] : tree.queryR(intervals[qr].s-1, D);
        assert(rint != -1);
        int rDir = sgn(intervals[rint].diff);
        bool rMore = abs(arr[R] - arr[intervals[rint].e]) >= D;

        /// 4. 답을 계산한다
        int base = (sum - (sum >= 2 && lDir == -1)) / 2;
        if(lDir == -1 && lMore) base++;
        if(rDir == 1 && rMore) base++;
        return base+1;
    }

    /// 구간 시작점이 하나도 없는 경우
    return 1;
}
# 결과 실행 시간 메모리 Grader output
1 Incorrect 449 ms 8144 KB 2nd lines differ - on the 1st token, expected: '1', found: '3'
2 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Incorrect 2 ms 2640 KB 1st lines differ - on the 1st token, expected: '13', found: '14'
2 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Incorrect 2 ms 2640 KB 1st lines differ - on the 1st token, expected: '13', found: '14'
2 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Incorrect 1333 ms 308756 KB 1st lines differ - on the 1st token, expected: '11903', found: '11905'
2 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Incorrect 446 ms 68556 KB 3rd lines differ - on the 1st token, expected: '150', found: '154'
2 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Incorrect 2 ms 2640 KB 1st lines differ - on the 1st token, expected: '13', found: '14'
2 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Incorrect 449 ms 8144 KB 2nd lines differ - on the 1st token, expected: '1', found: '3'
2 Halted 0 ms 0 KB -