제출 #803034

# | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리
803034 | | Ronin13 | 송신탑 (IOI22_towers) | C++17 | 100 / 100 | 1649 ms | 430372 KiB
// IOI 2022 "Radio Towers" (towers).
//
// init() builds a persistent segment tree over the tower heights; every later
// version marks one more tower, in decreasing order of a per-tower "slack"
// value.  max_towers() picks the version that matches D with a binary search
// and answers the query with a handful of O(log N) tree walks.

// The grader supplies towers.h.  The guard lets this file also build on its
// own (e.g. for local testing) when the header is not on the include path.
#if __has_include("towers.h")
#include "towers.h"
#endif

#include <algorithm>
#include <deque>
#include <utility>
#include <vector>

namespace {

constexpr int kInf = 1000000000;

// One segment-tree node.  Nodes are immutable once created and are shared
// between versions of the tree.
struct Node {
    int mx = 0;        // largest height in the segment
    int mn = kInf;     // smallest height in the segment
    int mnpos = -1;    // index at which `mn` is attained
    int rise = 0;      // max(H[j] - H[i]) over i < j inside the segment (>= 0)
    int drop = 0;      // max(H[i] - H[j]) over i < j inside the segment (>= 0)
    int sum = 0;       // number of marked towers in the segment
    const Node* l = nullptr;
    const Node* r = nullptr;
};

// By-value aggregate of an arbitrary index range.  The defaults form the
// identity element for merge(): combining with them changes nothing.
struct Spread {
    int mx = -kInf;
    int mn = kInf;
    int rise = 0;
    int drop = 0;
};

// All nodes live here.  std::deque keeps references to existing elements valid
// across emplace_back, so raw Node pointers into the pool stay usable, and the
// whole tree is released at once by clear().
std::deque<Node> pool;

std::vector<int> heights;          // tower heights, as passed to init()
std::vector<const Node*> roots;    // roots[k]: tree with the k highest-slack towers marked
std::vector<int> thresholds;       // thresholds[0] is a sentinel; thresholds[k] = -(slack of the k-th marked tower)
int tower_count = 0;

// Creates the parent of two (possibly shared) children and derives its
// aggregate values from them.  A missing child is treated as absent.
const Node* make_parent(const Node* left, const Node* right) {
    Node& p = pool.emplace_back();
    p.l = left;
    p.r = right;
    if (left != nullptr && right != nullptr) {
        p.mx = std::max(left->mx, right->mx);
        p.mn = std::min(left->mn, right->mn);
        p.mnpos = (p.mn == left->mn) ? left->mnpos : right->mnpos;
        // Best pair is inside one child, or spans both (low on one side, high on the other).
        p.rise = std::max({left->rise, right->rise, right->mx - left->mn});
        p.drop = std::max({left->drop, right->drop, left->mx - right->mn});
        p.sum = left->sum + right->sum;
    } else if (const Node* only = (left != nullptr ? left : right)) {
        p.mx = only->mx;
        p.mn = only->mn;
        p.mnpos = only->mnpos;
        p.rise = only->rise;
        p.drop = only->drop;
        p.sum = only->sum;
    }
    return &p;
}

// Builds the unmarked tree for heights[lo..hi] (inclusive, lo <= hi).
const Node* build(int lo, int hi) {
    if (lo == hi) {
        Node& leaf = pool.emplace_back();
        leaf.mx = leaf.mn = heights[lo];
        leaf.mnpos = lo;
        return &leaf;
    }
    const int mid = (lo + hi) / 2;
    const Node* left = build(lo, mid);
    const Node* right = build(mid + 1, hi);
    return make_parent(left, right);
}

// Returns the root of a new version in which position `pos` is marked.
// Only the nodes on the root-to-leaf path are copied; the rest is shared.
const Node* mark(const Node* v, int lo, int hi, int pos) {
    if (lo == hi) {
        Node& leaf = pool.emplace_back();
        if (v != nullptr) {
            leaf.mx = v->mx;
            leaf.mn = v->mn;
            leaf.mnpos = v->mnpos;
        }
        leaf.sum = 1;
        return &leaf;
    }
    const int mid = (lo + hi) / 2;
    const Node* left = (v != nullptr) ? v->l : nullptr;
    const Node* right = (v != nullptr) ? v->r : nullptr;
    if (pos <= mid) {
        const Node* new_left = mark(left, lo, mid, pos);
        return make_parent(new_left, right);
    }
    const Node* new_right = mark(right, mid + 1, hi, pos);
    return make_parent(left, new_right);
}

// Smallest index in [st, fin] whose leaf satisfies `ok`, or -1 if none.
// `ok` must be monotone: it holds for a node iff it holds for some leaf below.
template <class Pred>
int find_first(const Node* v, int lo, int hi, int st, int fin, Pred ok) {
    if (v == nullptr || lo > fin || hi < st) return -1;
    if (st <= lo && hi <= fin) {
        if (!ok(*v)) return -1;
        // The segment is fully inside the query: walk straight down.
        while (lo != hi) {
            const int mid = (lo + hi) / 2;
            if (v->l != nullptr && ok(*v->l)) {
                v = v->l;
                hi = mid;
            } else {
                v = v->r;
                lo = mid + 1;
            }
        }
        return lo;
    }
    const int mid = (lo + hi) / 2;
    const int found = find_first(v->l, lo, mid, st, fin, ok);
    return found != -1 ? found : find_first(v->r, mid + 1, hi, st, fin, ok);
}

// Largest index in [st, fin] whose leaf satisfies `ok`, or -1 if none.
// Same monotonicity requirement as find_first.
template <class Pred>
int find_last(const Node* v, int lo, int hi, int st, int fin, Pred ok) {
    if (v == nullptr || lo > fin || hi < st) return -1;
    if (st <= lo && hi <= fin) {
        if (!ok(*v)) return -1;
        while (lo != hi) {
            const int mid = (lo + hi) / 2;
            if (v->r != nullptr && ok(*v->r)) {
                v = v->r;
                lo = mid + 1;
            } else {
                v = v->l;
                hi = mid;
            }
        }
        return lo;
    }
    const int mid = (lo + hi) / 2;
    const int found = find_last(v->r, mid + 1, hi, st, fin, ok);
    return found != -1 ? found : find_last(v->l, lo, mid, st, fin, ok);
}

// Number of marked towers in [st, fin].
int range_sum(const Node* v, int lo, int hi, int st, int fin) {
    if (v == nullptr || lo > fin || hi < st) return 0;
    if (st <= lo && hi <= fin) return v->sum;
    const int mid = (lo + hi) / 2;
    return range_sum(v->l, lo, mid, st, fin) + range_sum(v->r, mid + 1, hi, st, fin);
}

// Largest height in [st, fin]; 0 for an empty intersection.
int range_max(const Node* v, int lo, int hi, int st, int fin) {
    if (v == nullptr || lo > fin || hi < st) return 0;
    if (st <= lo && hi <= fin) return v->mx;
    const int mid = (lo + hi) / 2;
    return std::max(range_max(v->l, lo, mid, st, fin), range_max(v->r, mid + 1, hi, st, fin));
}

// (smallest height, its index) in [st, fin]; (kInf, kInf) for an empty intersection.
std::pair<int, int> range_min(const Node* v, int lo, int hi, int st, int fin) {
    if (v == nullptr || lo > fin || hi < st) return {kInf, kInf};
    if (st <= lo && hi <= fin) return {v->mn, v->mnpos};
    const int mid = (lo + hi) / 2;
    return std::min(range_min(v->l, lo, mid, st, fin), range_min(v->r, mid + 1, hi, st, fin));
}

// Combines the aggregates of two adjacent ranges (`a` to the left of `b`).
Spread merge(const Spread& a, const Spread& b) {
    Spread out;
    out.mx = std::max(a.mx, b.mx);
    out.mn = std::min(a.mn, b.mn);
    out.rise = std::max({a.rise, b.rise, b.mx - a.mn});
    out.drop = std::max({a.drop, b.drop, a.mx - b.mn});
    return out;
}

// Aggregate of [st, fin], returned by value so that queries allocate nothing.
Spread range_spread(const Node* v, int lo, int hi, int st, int fin) {
    if (v == nullptr || lo > fin || hi < st) return Spread{};
    if (st <= lo && hi <= fin) return Spread{v->mx, v->mn, v->rise, v->drop};
    const int mid = (lo + hi) / 2;
    return merge(range_spread(v->l, lo, mid, st, fin),
                 range_spread(v->r, mid + 1, hi, st, fin));
}

}  // namespace

/// Preprocesses the towers.  May be called again; all previous state is discarded.
///
/// @param N  number of towers
/// @param H  tower heights, H[0..N-1] (the algorithm assumes positive heights
///           of at most 1e9, as the int arithmetic below relies on it)
void init(int N, std::vector<int> H) {
    pool.clear();
    roots.clear();
    tower_count = N;
    heights = std::move(H);
    thresholds.assign(1, -kInf);  // sentinel so that version 0 always qualifies
    if (N <= 0) {
        roots.push_back(nullptr);
        return;
    }

    const int last = N - 1;
    const Node* base = build(0, last);
    roots.reserve(static_cast<std::size_t>(N) + 1);
    roots.push_back(base);

    // Nearest strictly lower tower on each side (-1 / N when there is none).
    std::vector<int> prev_lower(N, -1);
    std::vector<int> next_lower(N, N);
    std::vector<int> stack;
    for (int i = 0; i < N; ++i) {
        while (!stack.empty() && heights[stack.back()] >= heights[i]) stack.pop_back();
        if (!stack.empty()) prev_lower[i] = stack.back();
        stack.push_back(i);
    }
    stack.clear();
    for (int i = last; i >= 0; --i) {
        while (!stack.empty() && heights[stack.back()] >= heights[i]) stack.pop_back();
        if (!stack.empty()) next_lower[i] = stack.back();
        stack.push_back(i);
    }

    // slack[i]: how far the highest tower between i and its nearest lower
    // neighbour rises above i, minimised over both sides; kInf when i has no
    // lower neighbour at all.
    std::vector<std::pair<int, int>> by_slack;
    by_slack.reserve(N);
    for (int i = 0; i < N; ++i) {
        int slack = kInf;
        if (prev_lower[i] != -1) {
            slack = std::min(slack, range_max(base, 0, last, prev_lower[i], i) - heights[i]);
        }
        if (next_lower[i] != N) {
            slack = std::min(slack, range_max(base, 0, last, i, next_lower[i]) - heights[i]);
        }
        by_slack.emplace_back(slack, i);
    }
    std::sort(by_slack.rbegin(), by_slack.rend());  // descending by (slack, index)

    // Version k marks the k towers with the largest slack.  Thresholds are
    // stored negated so that the list is ascending for upper_bound().
    thresholds.reserve(static_cast<std::size_t>(N) + 1);
    for (const auto& [slack, index] : by_slack) {
        roots.push_back(mark(roots.back(), 0, last, index));
        thresholds.push_back(-slack);
    }
}

/// Answers one query on the towers given to the latest init().
///
/// @param L  first index of the range (inclusive)
/// @param R  last index of the range (inclusive)
/// @param D  interference value, expected in [1, 1e9]
/// @return   the query answer; 0 when the range is empty or no towers exist
int max_towers(int L, int R, int D) {
    if (tower_count <= 0 || L > R) return 0;

    const int last = tower_count - 1;
    // Number of towers whose slack is at least D == index of the matching version.
    const int version =
        static_cast<int>(std::upper_bound(thresholds.begin(), thresholds.end(), -D) -
                         thresholds.begin()) - 1;
    const Node* marked = roots[version];
    const Node* base = roots[0];

    const auto is_marked = [](const Node& v) { return v.sum != 0; };
    const auto at_least = [](int limit) {
        return [limit](const Node& v) { return v.mx >= limit; };
    };

    int count = 0;
    int left_anchor = find_first(marked, 0, last, L, R, is_marked);
    int right_anchor = left_anchor;
    if (left_anchor == -1) {
        // Nothing marked inside the range: start from its lowest tower.
        left_anchor = right_anchor = range_min(marked, 0, last, L, R).second;
        count = 1;
    } else {
        right_anchor = find_last(marked, 0, last, L, R, is_marked);
        count = range_sum(marked, 0, last, L, R);
    }

    // One extra tower may fit to the left of the leftmost anchor: find the
    // closest tower at least D above the anchor, then require a rise of at
    // least D somewhere within [L, that tower].
    const int left_wall =
        find_last(base, 0, last, L, left_anchor, at_least(heights[left_anchor] + D));
    if (left_wall != -1 && range_spread(base, 0, last, L, left_wall).rise >= D) ++count;

    // Mirror image on the right side, using a drop of at least D.
    const int right_wall =
        find_first(base, 0, last, right_anchor, R, at_least(heights[right_anchor] + D));
    if (right_wall != -1 && range_spread(base, 0, last, right_wall, R).drop >= D) ++count;

    return count;
}

컴파일 시 표준 에러 (stderr) 메시지

towers.cpp: In function 'void init(int, std::vector<int>)':
towers.cpp:255:22: warning: comparison of integer expressions of different signedness: 'int' and 'std::vector<std::pair<int, int> >::size_type' {aka 'long unsigned int'} [-Wsign-compare]
  255 |     for(int i = 0; i < vv.size(); i++){
      |                    ~~^~~~~~~~~~~
towers.cpp:254:9: warning: unused variable 'ind' [-Wunused-variable]
  254 |     int ind = 1;
      |         ^~~
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...