이 제출은 이전 버전의 oj.uz에서 채점하였습니다. 현재는 제출 당시와는 다른 서버에서 채점을 하기 때문에, 다시 제출하면 결과가 달라질 수도 있습니다.
#include "towers.h"
#include <bits/stdc++.h>
#define ll long long
#define ull unsigned ll
#define f first
#define s second
#define pii pair<int,int>
#define pll pair<ll,ll>
#define pb push_back
#define epb emplace_back
using namespace std;
// Segment-tree node summarising a contiguous range of leaves.
//   mx / mxpos - maximum leaf value in the range and its leftmost index
//   mn / mnpos - minimum leaf value in the range and its leftmost index
//   mxdif[0]   - max(0, max over i < j of value[j] - value[i]) in the range
//   mxdif[1]   - max(0, max over i < j of value[i] - value[j]) in the range
//   sum        - sum of the leaf `sum` fields (used as a marked-leaf count)
// Every member has a default so that no constructor path leaves a field
// indeterminate (previously node(nullptr, nullptr) left `sum` and the
// positions uninitialised, and node() left the positions uninitialised).
struct node{
    int mx = 0, mn = 1000000000;
    int mxpos = -1, mnpos = -1;
    int mxdif[2] = {0, 0};
    int sum = 0;
    node *l = nullptr, *r = nullptr;

    node() = default;

    // Builds a parent from up to two children. A nullptr child is ignored:
    // with one child the parent copies that child's summary, with none the
    // defaults above are kept.
    node(node *l, node *r) : l(l), r(r){
        if(l && r){
            mx = std::max(l->mx, r->mx);
            mn = std::min(l->mn, r->mn);
            // Ties go to the left child, so positions are leftmost.
            mxpos = (mx == l->mx) ? l->mxpos : r->mxpos;
            mnpos = (mn == l->mn) ? l->mnpos : r->mnpos;
            // A pair (i, j) either lies inside one child or straddles the
            // split, in which case the best pair uses the extremes of each side.
            mxdif[0] = std::max({l->mxdif[0], r->mxdif[0], r->mx - l->mn});
            mxdif[1] = std::max({l->mxdif[1], r->mxdif[1], l->mx - r->mn});
            sum = l->sum + r->sum;
        }else if(l){
            copy_summary(*l);
        }else if(r){
            copy_summary(*r);
        }
    }

    // Copies every summary field (not the child pointers) from `child`.
    void copy_summary(const node &child){
        mx = child.mx;
        mn = child.mn;
        mxpos = child.mxpos;
        mnpos = child.mnpos;
        mxdif[0] = child.mxdif[0];
        mxdif[1] = child.mxdif[1];
        sum = child.sum;
    }
};
// Leftmost index in [st, fin] whose leaf in tree `v` has a non-zero sum,
// or -1 when no such leaf exists. [l, r] is the index range covered by v.
int getfirst(node *v, int l, int r, int st, int fin){
    if(v == nullptr || fin < l || r < st) return -1;

    if(!(st <= l && r <= fin)){
        // Partial overlap: try the left half first, then the right half.
        const int mid = (l + r) / 2;
        const int from_left = getfirst(v->l, l, mid, st, fin);
        return from_left > -1 ? from_left : getfirst(v->r, mid + 1, r, st, fin);
    }

    if(v->sum == 0) return -1;

    // Fully covered and non-empty: walk down, preferring the left child
    // whenever it still contains a marked leaf.
    node *cur = v;
    int lo = l, hi = r;
    while(lo < hi){
        const int mid = (lo + hi) / 2;
        if(cur->l != nullptr && cur->l->sum != 0){
            cur = cur->l;
            hi = mid;
        }else{
            cur = cur->r;
            lo = mid + 1;
        }
    }
    return lo;
}
// Rightmost index in [st, fin] whose leaf in tree `v` has a non-zero sum,
// or -1 when no such leaf exists. [l, r] is the index range covered by v.
int getlast(node *v, int l, int r, int st, int fin){
    if(v == nullptr || fin < l || r < st) return -1;

    if(!(st <= l && r <= fin)){
        // Partial overlap: try the right half first, then the left half.
        const int mid = (l + r) / 2;
        const int from_right = getlast(v->r, mid + 1, r, st, fin);
        return from_right > -1 ? from_right : getlast(v->l, l, mid, st, fin);
    }

    if(v->sum == 0) return -1;

    // Fully covered and non-empty: walk down, preferring the right child
    // whenever it still contains a marked leaf.
    node *cur = v;
    int lo = l, hi = r;
    while(lo < hi){
        const int mid = (lo + hi) / 2;
        if(cur->r != nullptr && cur->r->sum != 0){
            cur = cur->r;
            lo = mid + 1;
        }else{
            cur = cur->l;
            hi = mid;
        }
    }
    return lo;
}
// Leftmost index in [st, fin] whose leaf value (mx) is at least `val`,
// or -1 if none. [l, r] is the index range covered by v.
int getfirst1(node * v, int l, int r, int st, int fin, int val){
    if(v == nullptr || fin < l || r < st) return -1;

    if(!(st <= l && r <= fin)){
        const int mid = (l + r) / 2;
        const int hit = getfirst1(v->l, l, mid, st, fin, val);
        if(hit > -1) return hit;
        return getfirst1(v->r, mid + 1, r, st, fin, val);
    }

    if(v->mx < val) return -1;

    // Descend towards the leftmost leaf whose value reaches `val`.
    node *cur = v;
    int lo = l, hi = r;
    while(lo < hi){
        const int mid = (lo + hi) / 2;
        const bool left_qualifies = cur->l != nullptr && cur->l->mx >= val;
        if(left_qualifies){
            cur = cur->l;
            hi = mid;
        }else{
            cur = cur->r;
            lo = mid + 1;
        }
    }
    return lo;
}
// Rightmost index in [st, fin] whose leaf value (mx) is at least `val`,
// or -1 if none. [l, r] is the index range covered by v.
int getlast1(node *v, int l, int r, int st, int fin, int val){
    if(v == nullptr || fin < l || r < st) return -1;

    if(!(st <= l && r <= fin)){
        const int mid = (l + r) / 2;
        const int hit = getlast1(v->r, mid + 1, r, st, fin, val);
        if(hit > -1) return hit;
        return getlast1(v->l, l, mid, st, fin, val);
    }

    if(v->mx < val) return -1;

    // Descend towards the rightmost leaf whose value reaches `val`.
    node *cur = v;
    int lo = l, hi = r;
    while(lo < hi){
        const int mid = (lo + hi) / 2;
        const bool right_qualifies = cur->r != nullptr && cur->r->mx >= val;
        if(right_qualifies){
            cur = cur->r;
            lo = mid + 1;
        }else{
            cur = cur->l;
            hi = mid;
        }
    }
    return lo;
}
// Persistent point assignment: returns a new root in which the leaf at `pos`
// has sum = val. That leaf keeps v's mx/mn/mxpos/mnpos. Only the nodes on
// the root-to-leaf path are re-created; all other subtrees are shared with v.
node *update(node *v, int l, int r, int pos, int val){
    if(v == nullptr) v = new node();

    if(l != r){
        const int mid = (l + r) / 2;
        if(pos > mid) return new node(v->l, update(v->r, mid + 1, r, pos, val));
        return new node(update(v->l, l, mid, pos, val), v->r);
    }

    node *fresh = new node();
    fresh->mx = v->mx;
    fresh->mn = v->mn;
    fresh->mxpos = v->mxpos;
    fresh->mnpos = v->mnpos;
    fresh->sum = val;
    return fresh;
}
// Sum of the leaf `sum` fields of tree v over [st, fin] (0 for an empty
// intersection). [l, r] is the index range covered by v.
int getsum(node * v, int l, int r, int st, int fin){
    if(v == nullptr || fin < l || r < st) return 0;
    if(st <= l && r <= fin) return v->sum;
    const int mid = (l + r) / 2;
    const int left_part = getsum(v->l, l, mid, st, fin);
    const int right_part = getsum(v->r, mid + 1, r, st, fin);
    return left_part + right_part;
}
// Sentinel node returned by get1 for the parts of a query outside the range.
// init() allocates it and sets mx = -1e9, mn = 1e9.
node *no = nullptr;
// Summary (as a node value) of the part of v's subtree that lies inside
// [st, fin]. Parts outside the range contribute a copy of the sentinel `no`.
// Built by value so that no heap nodes are allocated per query. For
// partially covered ranges the child pointers are cleared, because they
// would otherwise point at this function's locals.
static node get1_summary(node *v, int l, int r, int st, int fin){
    if(l > fin || r < st || v == nullptr) return *no;
    if(l >= st && r <= fin) return *v;
    const int mid = (l + r) / 2;
    node lhs = get1_summary(v->l, l, mid, st, fin);
    node rhs = get1_summary(v->r, mid + 1, r, st, fin);
    node merged(&lhs, &rhs);
    merged.l = merged.r = nullptr;
    return merged;
}

// Returns the combined node summary of [st, fin] in tree v. [l, r] is the
// index range covered by v.
// The previous version allocated O(log n) nodes with `new` on every call and
// never freed them, so memory grew with every query. The result is now
// stored in a function-local static slot. The returned pointer is only
// valid until the next call to get1, and get1 is not reentrant.
// `ind` is unused and kept only for interface compatibility.
node *get1(node *v, int l, int r, int st, int fin, int ind){
    (void)ind;
    static node result;
    result = get1_summary(v, l, r, st, fin);
    return &result;
}
// Maximum leaf value (mx) of tree v over [st, fin]; 0 when the range does
// not intersect v. [l, r] is the index range covered by v.
int get_max(node *v, int l, int r, int st, int fin){
    if(v == nullptr || fin < l || r < st) return 0;
    if(st <= l && r <= fin) return v->mx;
    const int mid = (l + r) / 2;
    const int best_left = get_max(v->l, l, mid, st, fin);
    const int best_right = get_max(v->r, mid + 1, r, st, fin);
    return best_left < best_right ? best_right : best_left;
}
pii get2(node *v, int l, int r, int st, int fin){
if(l > fin || r < st || v == nullptr) return {1e9, 1e9};
if(l >= st && r <= fin){
return {v->mn, v->mnpos};
}
int m = (l + r) / 2;
return min(get2(v->l, l, m, st, fin),
get2(v->r, m + 1, r, st, fin));
}
// Capacity of the root[] array; init() writes root[0..N], so N < nmax is required.
constexpr int nmax = 200001;
// Tower heights indexed from 0; init() replaces the contents with H.
std::vector<int> a(nmax);
// Builds the base (non-persistent) tree over a[l..r]. Each leaf stores its
// height as both mx and mn, and its own index as mxpos/mnpos.
node *build(int l, int r){
    if(l != r){
        const int mid = (l + r) / 2;
        node *lhs = build(l, mid);
        node *rhs = build(mid + 1, r);
        return new node(lhs, rhs);
    }
    node *leaf = new node();
    leaf->mx = a[l];
    leaf->mn = a[l];
    leaf->mxpos = l;
    leaf->mnpos = l;
    return leaf;
}
// Thresholds filled by init(). pos[0] = -1e9 is a sentinel; pos[k] (k >= 1)
// is the negated k-th largest per-tower score, so the vector is ascending.
std::vector<int> pos;
// Persistent versions. root[0] is the base height tree; root[k] additionally
// marks (sum = 1) the k highest-scoring towers.
node *root[nmax];
// Number of towers passed to init().
int n;
// For every i in [0, len), returns the nearest index j on one side of i with
// h[j] < h[i] (strictly smaller), or `none` if there is no such index.
// When from_left is true it scans left to right and finds the nearest smaller
// index on the left; otherwise it scans right to left and finds the nearest
// smaller index on the right.
static std::vector<int> nearest_smaller(const std::vector<int> &h, int len, bool from_left, int none){
    std::vector<int> result(len, none);
    std::vector<int> stk; // candidate indices; heights strictly increase towards the top
    stk.reserve(len);
    for(int step = 0; step < len; step++){
        const int i = from_left ? step : len - 1 - step;
        while(!stk.empty() && h[stk.back()] >= h[i]) stk.pop_back();
        if(!stk.empty()) result[i] = stk.back();
        stk.push_back(i);
    }
    return result;
}

// Preprocesses the N tower heights H.
// Each tower i gets a score: over each side that has a strictly lower tower,
// take (highest tower between i and that nearest lower tower, inclusive) - H[i],
// and keep the minimum. The score is 1e9 if no side has a strictly lower tower.
// Towers are then marked in decreasing score order: root[k] marks the k
// highest-scoring towers, and pos[k] = -(k-th highest score). pos is therefore
// ascending, so max_towers can binary-search it.
// Changes from the previous version: std::vector replaces the non-standard
// VLAs (int l[n], r[n]), the unused `ind` and the signed/unsigned loop
// comparison are gone, and `pos` is cleared so a repeated init() does not
// accumulate stale thresholds.
void init(int N, std::vector<int> H) {
    // Sentinel used by get1 for parts of a query outside the requested range.
    no = new node();
    no->mx = -1000000000;
    no->mn = 1000000000;

    n = N;
    a = std::move(H);
    root[0] = build(0, n - 1);

    const std::vector<int> prev_smaller = nearest_smaller(a, n, true, -1);
    const std::vector<int> next_smaller = nearest_smaller(a, n, false, n);

    std::vector<std::pair<int, int>> scored; // (score, index)
    scored.reserve(n);
    for(int i = 0; i < n; i++){
        int score = 1000000000;
        if(prev_smaller[i] != -1)
            score = std::min(score, get_max(root[0], 0, n - 1, prev_smaller[i], i) - a[i]);
        if(next_smaller[i] != n)
            score = std::min(score, get_max(root[0], 0, n - 1, i, next_smaller[i]) - a[i]);
        scored.emplace_back(score, i);
    }
    // Descending by score (ties: larger index first, as with sort + reverse).
    std::sort(scored.begin(), scored.end(), std::greater<std::pair<int, int>>());

    pos.clear();
    pos.push_back(-1000000000);
    for(int k = 0; k < n; k++){
        root[k + 1] = update(root[k], 0, n - 1, scored[k].second, 1);
        pos.push_back(-scored[k].first);
    }
}
int max_towers(int L, int R, int D) {
int o = upper_bound(pos.begin(), pos.end(), -D) - pos.begin() - 1;
//cout << o << ' ';
int A = getfirst(root[o], 0, n - 1, L, R);
int B = getlast(root[o], 0, n - 1, L, R);
//cout << A << ' ' << B << "\n";
if(A == -1){
pii t = get2(root[o], 0, n - 1, L, R);
//cout << t.s << "\n";
int ans = 1;
int l = getlast1(root[0], 0, n - 1, L, t.s, a[t.s] + D);
// cout << l << ' ';
if(l > -1 && get1(root[0], 0, n - 1, L, l, 0)->mxdif[0] >= D) ans++;
int r = getfirst1(root[0], 0, n - 1, t.s, R, a[t.s] + D);
// cout << r << '\n';
if(r > -1&& get1(root[0], 0 , n - 1, r, R, 1)->mxdif[1] >= D) ans++;
return ans;
}
else{
ll ans = getsum(root[o], 0, n - 1, L, R);
int l = getlast1(root[0], 0, n - 1, L, A, a[A] + D);
// cout << l << ' ';
if(l > -1 && get1(root[0], 0, n - 1, L, l, 0)->mxdif[0] >= D) ans++;
int r = getfirst1(root[0], 0, n - 1, B, R, a[B] + D);
// cout << l <<' ' << r << "\n";
// cout << get1(root[0], 0, n - 1, r, R, 1) << "\n";
if(r > -1 && get1(root[0], 0 , n - 1, r, R, 1)->mxdif[1] >= D) ans++;
return ans;
}
return 0;
}
컴파일 시 표준 에러 (stderr) 메시지
towers.cpp: In function 'void init(int, std::vector<int>)':
towers.cpp:255:22: warning: comparison of integer expressions of different signedness: 'int' and 'std::vector<std::pair<int, int> >::size_type' {aka 'long unsigned int'} [-Wsign-compare]
255 | for(int i = 0; i < vv.size(); i++){
| ~~^~~~~~~~~~~
towers.cpp:254:9: warning: unused variable 'ind' [-Wunused-variable]
254 | int ind = 1;
| ^~~
# | Verdict | Execution time | Memory | Grader output |
---|---|---|---|---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|---|---|---|---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|---|---|---|---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|---|---|---|---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|---|---|---|---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|---|---|---|---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|---|---|---|---|
Fetching results... |