제출 #803034

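| # | 제출 시각 (submitted) | 아이디 (user) | 문제 (problem) | 언어 (language) | 결과 (result) | 실행 시간 (time) | 메모리 (memory) |
|---|---|---|---|---|---|---|---|
| 803034 | | Ronin13 | Radio Towers (IOI22_towers) | C++17 | 100 / 100 | 1649 ms | 430372 KiB |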
#include "towers.h"
#include <bits/stdc++.h>
#define ll long long 
#define ull unsigned ll
#define f first
#define s second
#define pii pair<int,int>
#define pll pair<ll,ll>
#define pb push_back
#define epb emplace_back
using namespace std;

/// Segment-tree node, used both by the static tree from `build` and by the
/// persistent versions produced by `update`.
///
///  - mx / mn       : maximum / minimum leaf value in the node's range.
///  - mxpos / mnpos : index where mx / mn is attained (left child wins ties);
///                    -1 when the node has no known position.
///  - mxdif[0]      : max over i <= j in range of (value[j] - value[i]), >= 0.
///  - mxdif[1]      : max over i <= j in range of (value[i] - value[j]), >= 0.
///  - sum           : sum of leaf `sum` values (leaves get theirs from `update`).
///  - l / r         : children (may be null).
///
/// Every member has an in-class initializer so that no constructor path
/// (in particular node(nullptr, nullptr)) leaves a field indeterminate.
struct node{
    int mx = 0, mn = 1000000000;
    int mxpos = -1, mnpos = -1;
    int mxdif[2] = {0, 0};
    int sum = 0;
    node *l = nullptr, *r = nullptr;

    node() = default;

    /// Builds the parent of `left` and `right`. Either child may be null:
    /// with one child the summary is copied from it, and with none the
    /// defaults above are kept.
    node(node *left, node *right) : l(left), r(right){
      if(left && right){
        mx = std::max(left->mx, right->mx);
        mn = std::min(left->mn, right->mn);
        mxpos = (mx == left->mx) ? left->mxpos : right->mxpos;
        mnpos = (mn == left->mn) ? left->mnpos : right->mnpos;
        // The best pair either lies inside one child or spans both
        // (i in left, j in right).
        mxdif[0] = std::max({left->mxdif[0], right->mxdif[0], right->mx - left->mn});
        mxdif[1] = std::max({left->mxdif[1], right->mxdif[1], left->mx - right->mn});
        sum = left->sum + right->sum;
      } else if(left || right){
        const node *only = left ? left : right;
        mx = only->mx;
        mn = only->mn;
        mxpos = only->mxpos;
        mnpos = only->mnpos;
        mxdif[0] = only->mxdif[0];
        mxdif[1] = only->mxdif[1];
        sum = only->sum;
      }
    }
};

/// Leftmost index in [st, fin] whose leaf carries a nonzero `sum`, or -1.
/// `v` covers the segment [l, r]; a null `v` counts as empty.
int getfirst(node *v, int l, int r, int st, int fin){
    if(v == nullptr || st > r || fin < l) return -1;
    const bool covered = st <= l && r <= fin;
    if(!covered){
      const int mid = (l + r) / 2;
      const int fromLeft = getfirst(v->l, l, mid, st, fin);
      return fromLeft >= 0 ? fromLeft : getfirst(v->r, mid + 1, r, st, fin);
    }
    if(v->sum == 0) return -1;
    // Segment lies inside the query: walk down, preferring the left child.
    node *cur = v;
    int lo = l, hi = r;
    while(lo < hi){
      const int mid = (lo + hi) / 2;
      if(cur->l != nullptr && cur->l->sum != 0){
        cur = cur->l;
        hi = mid;
      } else {
        cur = cur->r;
        lo = mid + 1;
      }
    }
    return lo;
}

/// Rightmost index in [st, fin] whose leaf carries a nonzero `sum`, or -1.
/// `v` covers the segment [l, r]; a null `v` counts as empty.
int getlast(node *v, int l, int r, int st, int fin){
    if(v == nullptr || st > r || fin < l) return -1;
    if(st <= l && r <= fin){
      if(v->sum == 0) return -1;
      // Segment lies inside the query: walk down, preferring the right child.
      int lo = l, hi = r;
      for(node *cur = v; lo < hi; ){
        const int mid = (lo + hi) / 2;
        const bool rightHasMark = cur->r != nullptr && cur->r->sum != 0;
        if(rightHasMark){
          cur = cur->r;
          lo = mid + 1;
        } else {
          cur = cur->l;
          hi = mid;
        }
      }
      return lo;
    }
    const int mid = (l + r) / 2;
    const int fromRight = getlast(v->r, mid + 1, r, st, fin);
    if(fromRight >= 0) return fromRight;
    return getlast(v->l, l, mid, st, fin);
}

/// Leftmost index in [st, fin] whose leaf `mx` is >= val, or -1.
/// `v` covers the segment [l, r]; a null `v` counts as empty.
int getfirst1(node * v, int l, int r, int st, int fin, int val){
    if(v == nullptr || st > r || fin < l) return -1;
    if(st <= l && r <= fin){
      if(v->mx < val) return -1;
      // Some leaf here qualifies: descend, preferring the left child.
      node *cur = v;
      int lo = l, hi = r;
      while(lo < hi){
        const int mid = (lo + hi) / 2;
        if(cur->l != nullptr && cur->l->mx >= val){
          hi = mid;
          cur = cur->l;
        } else {
          lo = mid + 1;
          cur = cur->r;
        }
      }
      return lo;
    }
    const int mid = (l + r) / 2;
    const int found = getfirst1(v->l, l, mid, st, fin, val);
    return found >= 0 ? found : getfirst1(v->r, mid + 1, r, st, fin, val);
}

/// Rightmost index in [st, fin] whose leaf `mx` is >= val, or -1.
/// `v` covers the segment [l, r]; a null `v` counts as empty.
int getlast1(node *v, int l, int r, int st, int fin, int val){
    if(v == nullptr || st > r || fin < l) return -1;
    if(st <= l && r <= fin){
      if(v->mx < val) return -1;
      // Some leaf here qualifies: descend, preferring the right child.
      node *cur = v;
      int lo = l, hi = r;
      while(lo < hi){
        const int mid = (lo + hi) / 2;
        if(cur->r != nullptr && cur->r->mx >= val){
          lo = mid + 1;
          cur = cur->r;
        } else {
          hi = mid;
          cur = cur->l;
        }
      }
      return lo;
    }
    const int mid = (l + r) / 2;
    const int found = getlast1(v->r, mid + 1, r, st, fin, val);
    return found >= 0 ? found : getlast1(v->l, l, mid, st, fin, val);
}

/// Persistent point update. Returns the root of a new version in which the
/// leaf at `pos` has `sum = val`. The leaf keeps its mx/mn/mxpos/mnpos; all
/// other nodes off the root-to-leaf path are shared with `v`. A null `v` is
/// treated as a default node.
node *update(node *v, int l, int r, int pos, int val){
    node *cur = (v != nullptr) ? v : new node();
    if(l == r){
      node *leaf = new node();
      leaf->sum = val;
      leaf->mx = cur->mx;
      leaf->mn = cur->mn;
      leaf->mxpos = cur->mxpos;
      leaf->mnpos = cur->mnpos;
      return leaf;
    }
    const int mid = (l + r) / 2;
    if(pos > mid)
      return new node(cur->l, update(cur->r, mid + 1, r, pos, val));
    return new node(update(cur->l, l, mid, pos, val), cur->r);
}

/// Sum of leaf `sum` values over [st, fin]. Contributes 0 for parts outside
/// the query and for null nodes.
int getsum(node * v, int l, int r, int st, int fin){
    if(v == nullptr || st > r || fin < l) return 0;
    if(st <= l && r <= fin) return v->sum;
    const int mid = (l + r) / 2;
    const int leftPart = getsum(v->l, l, mid, st, fin);
    const int rightPart = getsum(v->r, mid + 1, r, st, fin);
    return leftPart + rightPart;
}
// Sentinel "empty" node that get1 uses for parts of the tree outside [st, fin].
// init() gives it mx = -1e9 and mn = 1e9, so merging it with a real node keeps
// that node's mx/mn values (assuming heights lie within [-1e9, 1e9] — TODO
// confirm against the problem constraints).
node *no;
/// Merged summary (as a node value) of the leaves in [st, fin]. Parts outside
/// the query, and null nodes, contribute a copy of the sentinel `no`.
/// Works entirely on values, so no heap allocation happens per call.
static node get1_summary(node *v, int l, int r, int st, int fin){
  if(l > fin || r < st || v == nullptr)
    return *no;
  if(l >= st && r <= fin)
    return *v;
  int m = (l + r) / 2;
  node left = get1_summary(v->l, l, m, st, fin);
  node right = get1_summary(v->r, m + 1, r, st, fin);
  node merged(&left, &right);
  // The children are locals of this frame; don't let the result point at them.
  merged.l = merged.r = nullptr;
  return merged;
}

/// Returns a node summarizing [st, fin] of the tree rooted at `v`, covering
/// [l, r]. The summary includes mx, mn and mxdif.
///
/// The previous version allocated O(log n) fresh nodes per call and never
/// freed them. This version returns a pointer to a function-local static
/// instead. The pointee is only valid until the next call to get1, so read
/// what you need immediately, as max_towers does. `ind` is unused and kept
/// for interface compatibility.
node *get1(node *v, int l, int r, int st, int fin, [[maybe_unused]] int ind){
  static node result;
  result = get1_summary(v, l, r, st, fin);
  return &result;
}

/// Maximum `mx` over [st, fin]. Parts outside the query, and null nodes,
/// contribute 0.
int get_max(node *v, int l, int r, int st, int fin){
  if(v == nullptr || st > r || fin < l) return 0;
  if(st <= l && r <= fin) return v->mx;
  const int mid = (l + r) / 2;
  const int leftMax = get_max(v->l, l, mid, st, fin);
  const int rightMax = get_max(v->r, mid + 1, r, st, fin);
  return leftMax < rightMax ? rightMax : leftMax;
}

pii get2(node *v, int l, int r, int st, int fin){
    if(l > fin || r < st || v == nullptr) return {1e9, 1e9};
    if(l >= st && r <= fin){
      return {v->mn, v->mnpos};
    }
    int m = (l + r) / 2;
    return min(get2(v->l, l, m, st, fin),
              get2(v->r, m + 1, r, st, fin));
}
// Capacity of the initial `a` buffer and of the `root` version array. root
// needs n + 1 entries (root[0..n]), so N must not exceed nmax - 1.
const int nmax = 200001;
// Tower heights; init() replaces the contents with H.
vector <int> a(nmax);

/// Builds the static segment tree over a[l..r]. Each leaf stores its height
/// as both mx and mn and its own index as both mxpos and mnpos; internal
/// nodes are merged by the node(left, right) constructor.
node *build(int l, int r){
    if(l != r){
      const int mid = (l + r) / 2;
      node *leftChild = build(l, mid);
      node *rightChild = build(mid + 1, r);
      return new node(leftChild, rightChild);
    }
    node *leaf = new node();
    leaf->mx = a[l];
    leaf->mn = a[l];
    leaf->mxpos = l;
    leaf->mnpos = l;
    return leaf;
}

// Thresholds for picking a persistent version. pos[0] is a -1e9 sentinel, and
// pos[k] (k >= 1) is minus the score init() computed for the k-th activated
// tower. Towers are activated in decreasing score order, so pos is
// non-decreasing.
vector <int> pos;
// root[0]: static tree over the heights. root[k]: version in which the first
// k activated towers have leaf sum = 1.
node *root[nmax];
// Number of towers (N passed to init).
int n;
/// For every index i, returns the index of the nearest tower strictly lower
/// than h[i]: on the left when `toLeft` is true, otherwise on the right.
/// Indices with no such tower get `none`. Uses a monotonic stack, O(n).
static std::vector<int> nearest_lower(const std::vector<int> &h, bool toLeft, int none){
  const int cnt = static_cast<int>(h.size());
  std::vector<int> res(cnt, none);
  std::vector<int> stk;
  for(int k = 0; k < cnt; k++){
    const int i = toLeft ? k : cnt - 1 - k;
    // Drop candidates that are not strictly lower than h[i].
    while(!stk.empty() && h[stk.back()] >= h[i]) stk.pop_back();
    if(!stk.empty()) res[i] = stk.back();
    stk.push_back(i);
  }
  return res;
}

/// Preprocessing for max_towers.
///
/// For each tower i, the score o is computed as follows. On each side that
/// has a strictly lower tower, take the highest point between i and that
/// tower, minus a[i]. Score o is the minimum over those sides, or 1e9 if
/// neither side has a lower tower. Towers are then activated one by one in
/// decreasing score order: root[k] marks the first k of them with sum = 1,
/// and pos[k] = -(k-th score).
///
/// Changes from the original:
///  - The VLAs `int l[n], r[n]` are replaced by vectors. VLAs are not standard
///    C++, and these ones put 2*n ints on the stack.
///  - The unused `ind` and the signed/unsigned loop comparison are removed.
///  - `pos` is reset, so calling init() again does not accumulate thresholds.
void init(int N, std::vector<int> H) {
  no = new node();
  no->mx = -1000000000, no->mn = 1000000000;
  n = N;
  a = std::move(H);
  root[0] = build(0, n - 1);

  const std::vector<int> lowerLeft = nearest_lower(a, true, -1);
  const std::vector<int> lowerRight = nearest_lower(a, false, n);

  std::vector<std::pair<int, int>> vv;
  vv.reserve(n);
  for(int i = 0; i < n; i++){
    int o = 1000000000;
    if(lowerLeft[i] != -1) o = std::min(o, get_max(root[0], 0, n - 1, lowerLeft[i], i) - a[i]);
    if(lowerRight[i] != n) o = std::min(o, get_max(root[0], 0, n - 1, i, lowerRight[i]) - a[i]);
    vv.emplace_back(o, i);
  }
  // Descending order. All pairs are distinct (unique i), so this matches the
  // original sort-then-reverse.
  std::sort(vv.begin(), vv.end(), std::greater<std::pair<int, int>>());

  pos.clear();
  pos.push_back(-1000000000);
  for(int k = 0; k < n; k++){
    root[k + 1] = update(root[k], 0, n - 1, vv[k].second, 1);
    pos.push_back(-vv[k].first);
  }
}

int max_towers(int L, int R, int D) {
  int o = upper_bound(pos.begin(), pos.end(), -D) - pos.begin() - 1;
  //cout << o << ' ';
  int A = getfirst(root[o], 0, n - 1, L, R);
  int B = getlast(root[o], 0, n - 1, L, R);
  //cout << A << ' ' << B << "\n";
  if(A == -1){
      pii t = get2(root[o], 0, n - 1, L, R);
      //cout << t.s << "\n";
      int ans = 1;
      int l = getlast1(root[0], 0, n - 1, L, t.s, a[t.s] + D);
    //  cout << l << ' ';
      if(l > -1 && get1(root[0], 0, n - 1, L, l, 0)->mxdif[0] >= D) ans++;
      int r = getfirst1(root[0], 0, n - 1, t.s, R, a[t.s] + D);
    //  cout << r << '\n';
      if(r > -1&& get1(root[0], 0 , n - 1, r, R, 1)->mxdif[1] >= D) ans++;
      return ans;
  }
  else{
    ll ans = getsum(root[o], 0, n - 1, L, R);
    
    int l = getlast1(root[0], 0, n - 1, L, A, a[A] + D);
   // cout << l << ' ';
    if(l > -1 && get1(root[0], 0, n - 1, L, l, 0)->mxdif[0] >= D) ans++;
    int r = getfirst1(root[0], 0, n - 1, B, R, a[B] + D);
   // cout << l <<' ' << r << "\n";
   // cout << get1(root[0], 0, n - 1, r, R, 1) << "\n";
    if(r > -1 && get1(root[0], 0 , n - 1, r, R, 1)->mxdif[1] >= D) ans++;
    return ans;
  }
  return 0;
}

컴파일 시 표준 에러 (stderr) 메시지

towers.cpp: In function 'void init(int, std::vector<int>)':
towers.cpp:255:22: warning: comparison of integer expressions of different signedness: 'int' and 'std::vector<std::pair<int, int> >::size_type' {aka 'long unsigned int'} [-Wsign-compare]
  255 |     for(int i = 0; i < vv.size(); i++){
      |                    ~~^~~~~~~~~~~
towers.cpp:254:9: warning: unused variable 'ind' [-Wunused-variable]
  254 |     int ind = 1;
      |         ^~~
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...