#include "towers.h"
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
int n;
int arr[100002];
void calculateIntervals();
void init(int N, vector<int> H){
n = N;
for(int i=0; i<n; i++) arr[i] = H[i];
calculateIntervals();
}
/** PST 구조체
* 1. 구간 합 쿼리 처리 기능
* 2. 어떤 pair (x, y)가 들어올 때, 지금까지 들어온 pair의 max 처리 기능
*/
struct PST{
struct Node{
Node *lc, *rc;
int sum;
pair<int, int> lPair, rPair;
Node(Node* node){
lc = node->lc, rc = node->rc;
sum = node->sum;
lPair = node->lPair, rPair = node->rPair;
}
Node(int l, int r){
lc = rc = nullptr;
sum = 0;
lPair = rPair = make_pair(-1, -1);
if(l==r) return;
int m = (l+r)>>1;
lc = new Node(l, m);
rc = new Node(m+1, r);
}
Node* update(int l, int r, int x, int v){
if(l==r){
Node *node = new Node(this);
node->sum += v;
return node;
}
int m = (l+r)>>1;
Node *node = new Node(this);
if(x<=m){
node->lc = lc->update(l, m, x, v);
node->sum += v;
return node;
}
else{
node->rc = rc->update(m+1, r, x, v);
node->sum += v;
return node;
}
}
Node* updateL(int l, int r, int s, int e, pair<int, int> p){
if(r<s || e<l) return this;
if(s<=l && r<=e){
Node *node = new Node(this);
node->lPair = p;
return node;
}
int m = (l+r)>>1;
Node *node = new Node(this);
node->lc = lc->updateL(l, m, s, e, p);
node->rc = rc->updateL(m+1, r, s, e, p);
return node;
}
Node* updateR(int l, int r, int s, int e, pair<int, int> p){
if(r<s || e<l) return this;
if(s<=l && r<=e){
Node *node = new Node(this);
node->rPair = p;
return node;
}
int m = (l+r)>>1;
Node *node = new Node(this);
node->lc = lc->updateR(l, m, s, e, p);
node->rc = rc->updateR(m+1, r, s, e, p);
return node;
}
int query(int l, int r, int s, int e){
if(r<s || e<l) return 0;
if(s<=l && r<=e) return sum;
int m = (l+r)>>1;
return lc->query(l, m, s, e) + rc->query(m+1, r, s, e);
}
pair<int, int> queryL(int l, int r, int x){
if(l==r) return lPair;
int m = (l+r)>>1;
if(x<=m) return max(lPair, lc->queryL(l, m, x));
else return max(lPair, rc->queryL(m+1, r, x));
}
pair<int, int> queryR(int l, int r, int x){
if(l==r) return rPair;
int m = (l+r)>>1;
if(x<=m) return max(rPair, lc->queryR(l, m, x));
else return max(rPair, rc->queryR(m+1, r, x));
}
} *root;
int n, curD;
Node *history[200002];
PST(){}
void init(int N){
n = N;
root = new Node(0, n-1);
curD = 0;
memset(history, 0, sizeof(history));
}
void nextD(){
history[curD++] = root;
}
void update(int x, int v){
root = root->update(0, n-1, x, v);
}
void updateL(int l, int r, int idx){
root = root->updateL(0, n-1, l, r, make_pair(curD, idx));
}
void updateR(int l, int r, int idx){
root = root->updateR(0, n-1, l, r, make_pair(curD, idx));
}
vector<int> query(int l, int r, int d){
/// 정수 3개의 배열을 리턴
vector<int> ret (3);
ret[0] = history[d]->query(0, n-1, l, r);
ret[1] = history[d]->queryL(0, n-1, l).second;
ret[2] = history[d]->queryR(0, n-1, r).second;
return ret;
}
int queryL(int x, int d){
return history[d]->queryL(0, n-1, x).second;
}
int queryR(int x, int d){
return history[d]->queryR(0, n-1, x).second;
}
} tree;
struct Interval{
int s, e, diff, idx;
Interval(int s, int e, int diff, int idx): s(s), e(e), diff(diff), idx(idx){}
};
struct xOrder {
bool operator() (const Interval &l, const Interval &r)const{
return l.s < r.s;
}
};
struct diffOrder {
bool operator() (const Interval &l, const Interval &r)const{
if(abs(l.diff) != abs(r.diff)) return abs(l.diff) < abs(r.diff);
return l.s < r.s;
}
};
set<Interval, xOrder> xSet;
set<Interval, diffOrder> diffSet;
int sgn(int x){
return x > 0 ? 1 : x == 0 ? 0 : -1;
}
vector<int> diffScale;
vector<Interval> intervals;
int leftInt[200002], rightInt[200002];
void makeInitialIntervals(){ /// 초기의 구간들을 계산한다
tree.init(n);
int prv = 0;
while(prv < n-1){
int j = prv;
while(j+1<n && sgn(arr[j] - arr[prv]) * sgn(arr[j+1] - arr[j]) >= 0) j++;
if(arr[j] == arr[prv]) break;
/// 현재 찾은 구간을 셋에 집어넣는다
int idx = (int)intervals.size();
Interval tmp = Interval(prv, j, arr[j] - arr[prv], idx);
intervals.push_back(tmp);
xSet.insert(tmp);
diffSet.insert(tmp);
/// 현재 찾은 구간을 트리에 업데이트한다
tree.update(prv, 1);
tree.updateL(prv+1, j, idx);
tree.updateR(prv, j-1, idx);
prv = j;
}
diffScale.push_back(0);
leftInt[tree.curD] = (xSet.empty() ? -1 : xSet.begin()->idx);
rightInt[tree.curD] = (xSet.empty() ? -1 : xSet.rbegin()->idx);
tree.nextD();
}
void calculateChanges(){ /// D가 커짐에 따라 생기는 변화들을 구한다
while((int)diffSet.size()){
int D = abs(diffSet.begin()->diff); /// 현재의 차이값
while(abs((int)diffSet.begin()->diff) == D){ /// 해당하는 차이의 모든 구간을 처리한다.
/// 구간을 뽑는다
Interval p = *diffSet.begin();
diffSet.erase(diffSet.begin());
/// Case 1. 직전에 구간이 존재하지 않음
if(xSet.begin()->s == p.s){ /// 이 구간은 그냥 삭제한다.
xSet.erase(xSet.begin());
tree.updateL(0, p.e, -1);
}
/// Case 2. 직후에 구간이 존재하지 않음
else if(xSet.rbegin()->s == p.s){ /// 이 구간 역시 그냥 삭제한다.
xSet.erase(prev(xSet.end()));
tree.updateR(p.s, n-1, -1);
}
/// Case 3. 양쪽 모두 잘 존재함
else{ /// 새로운 구간을 만들어 준다.
auto it = xSet.find(p);
int ns = prev(it)->s, ne = next(it)->e;
int idx = (int)intervals.size();
Interval np = Interval(ns, ne, arr[ne] - arr[ns], idx);
intervals.push_back(np);
/// 셋에 반영한다.
diffSet.erase(*prev(it)), diffSet.erase(*next(it));
xSet.erase(prev(it)), xSet.erase(next(it));
xSet.erase(it);
xSet.insert(np), diffSet.insert(np);
/// 트리에 반영한다
tree.update(p.s, -1);
tree.update(p.e, -1);
tree.updateL(np.s+1, np.e, idx);
tree.updateR(np.s, np.e-1, idx);
}
}
diffScale.push_back(D+1);
leftInt[tree.curD] = (xSet.empty() ? -1 : xSet.begin()->idx);
rightInt[tree.curD] = (xSet.empty() ? -1 : xSet.rbegin()->idx);
tree.nextD();
}
leftInt[tree.curD] = rightInt[tree.curD] = -1;
}
void calculateIntervals(){
makeInitialIntervals();
calculateChanges();
}
int max_towers(int L, int R, int D){
D = upper_bound(diffScale.begin(), diffScale.end(), D) - diffScale.begin() - 1;
if(leftInt[D] == -1) return 1; /// 현 시기에 구간이 하나도 없는 경우
vector<int> response = tree.query(L, R, D);
int sum = response[0], ql = response[1], qr = response[2];
/// 구간 시작점이 하나라도 있는 경우
if(sum){
/// 1. 완전히 속하는 구간의 개수를 센다.
/// 왼쪽 끝 부분은 예외가 생기지 않는다.
/// 오른쪽 끝 부분은 맨 오른쪽 점이 어느 구간에 포함되지 않았을 때를 제외하고는 1을 제외해야 한다.
int insideIntervalCount = sum;
if(qr != -1) insideIntervalCount--; /// 시작점이
/// 2. 왼쪽 끝을 본다.
/// 위에서 본 구간 중 맨 왼쪽 구간을 찾는다.
int lint = (ql == -1) ? leftInt[D] : tree.queryL(intervals[ql].e+1, D);
assert(lint != -1);
int lDir = sgn(intervals[lint].diff);
bool lMore = abs(arr[L] - arr[intervals[lint].s]) >= D;
/// 3. 오른쪽 끝을 본다
/// 위에서 본 구간 중 맨 오른쪽 구간을 찾는다.
int rint = (qr == -1) ? rightInt[D] : tree.queryR(intervals[qr].s-1, D);
assert(rint != -1);
int rDir = sgn(intervals[rint].diff);
bool rMore = abs(arr[R] - arr[intervals[rint].e]) >= D;
/// 4. 답을 계산한다
int base = (sum - (sum >= 2 && lDir == -1)) / 2;
if(lDir == -1 && lMore) base++;
if(rDir == 1 && rMore) base++;
return base+1;
}
/// 구간 시작점이 하나도 없는 경우
return 1;
}
# |
Verdict |
Execution time |
Memory |
Grader output |
1 |
Incorrect |
449 ms |
8144 KB |
2nd lines differ - on the 1st token, expected: '1', found: '3' |
2 |
Halted |
0 ms |
0 KB |
- |
# |
Verdict |
Execution time |
Memory |
Grader output |
1 |
Incorrect |
2 ms |
2640 KB |
1st lines differ - on the 1st token, expected: '13', found: '14' |
2 |
Halted |
0 ms |
0 KB |
- |
# |
Verdict |
Execution time |
Memory |
Grader output |
1 |
Incorrect |
2 ms |
2640 KB |
1st lines differ - on the 1st token, expected: '13', found: '14' |
2 |
Halted |
0 ms |
0 KB |
- |
# |
Verdict |
Execution time |
Memory |
Grader output |
1 |
Incorrect |
1333 ms |
308756 KB |
1st lines differ - on the 1st token, expected: '11903', found: '11905' |
2 |
Halted |
0 ms |
0 KB |
- |
# |
Verdict |
Execution time |
Memory |
Grader output |
1 |
Incorrect |
446 ms |
68556 KB |
3rd lines differ - on the 1st token, expected: '150', found: '154' |
2 |
Halted |
0 ms |
0 KB |
- |
# |
Verdict |
Execution time |
Memory |
Grader output |
1 |
Incorrect |
2 ms |
2640 KB |
1st lines differ - on the 1st token, expected: '13', found: '14' |
2 |
Halted |
0 ms |
0 KB |
- |
# |
Verdict |
Execution time |
Memory |
Grader output |
1 |
Incorrect |
449 ms |
8144 KB |
2nd lines differ - on the 1st token, expected: '1', found: '3' |
2 |
Halted |
0 ms |
0 KB |
- |