제출 #1339059

# | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리
1339059 | madamadam3 | Radio Towers (IOI22_towers) | C++20
17 / 100
2137 ms | 1177820 KiB
#include "towers.h"
#include <bits/stdc++.h>

using namespace std;
#define all(x) (x).begin(), (x).end()
using vi = vector<int>;
using pi = pair<int, int>;

// Static segment tree over an int array answering range-maximum and
// range-minimum queries on half-open index intervals [ql, qr).
struct SegTree {
  int n = 0;
  std::vector<int> arr;
  // st[node] = {max, min} over the node's range; {INT_MIN, INT_MAX} is the
  // identity of combine() and is returned for empty / disjoint ranges.
  std::vector<std::pair<int, int>> st;

  SegTree() {}
  SegTree(int n, const std::vector<int>& arr)
      : n(n), arr(arr), st(4 * std::max(n, 0), {INT_MIN, INT_MAX}) {
    // An empty array has no leaves: building [0, 0) would split forever.
    if (n > 0) build(0, 0, n);
  }

  std::pair<int, int> combine(std::pair<int, int> a, std::pair<int, int> b) {
    return {std::max(a.first, b.first), std::min(a.second, b.second)};
  }

  // Fills node i covering [l, r) and returns its {max, min}.
  std::pair<int, int> build(int i, int l, int r) {
    if (l + 1 == r) return st[i] = {arr[l], arr[l]};
    int m = (l + r) >> 1;
    return st[i] = combine(build(2 * i + 1, l, m), build(2 * i + 2, m, r));
  }

  // Max (get_max = true) or min of arr[ql, qr); INT_MIN / INT_MAX when empty.
  int query(int ql, int qr, bool get_max = true) {
    if (ql >= qr) return get_max ? INT_MIN : INT_MAX;
    auto v = query(0, 0, n, ql, qr);
    return get_max ? v.first : v.second;
  }

  // {max, min} of arr over [ql, qr) intersected with node i's range [l, r).
  std::pair<int, int> query(int i, int l, int r, int ql, int qr) {
    if (qr <= l || r <= ql) return {INT_MIN, INT_MAX};
    if (ql <= l && r <= qr) return st[i];
    int m = (l + r) >> 1;
    return combine(query(2 * i + 1, l, m, ql, qr), query(2 * i + 2, m, r, ql, qr));
  }
};

// Sentinel "unbounded" height gap, used when a tower has no lower tower on a side.
static constexpr int INF = 1'000'000'000;

// Persistent 2D point counter.
//
// Holds points (key, b, a) with b and a in [0, M). After init(),
// query(D, bql, bqr, aql, aqr) counts the points with key >= D whose (b, a)
// lies in the half-open rectangle [bql, bqr) x [aql, aqr).
//
// Points are inserted in descending key order into a segment tree over b in
// which every node owns a segment tree over a. Each insertion copies the
// touched path in both dimensions, so older versions stay intact; one version
// root is recorded per distinct key.
struct Persistent2DCounter {
  struct InnerNode {
    int l = 0, r = 0, sum = 0;
  };
  struct OuterNode {
    int l = 0, r = 0, inner = 0;
  };
  struct Point {
    int key; // active iff key >= D
    int b;   // outer dimension
    int a;   // inner dimension
  };

  int M = 0; // coordinate domain size [0, M)
  std::vector<InnerNode> inner_nodes; // node 0 is the shared empty node
  std::vector<OuterNode> outer_nodes; // node 0 is the shared empty node
  std::vector<int> keys_desc;         // distinct keys, in descending order
  std::vector<int> roots;             // roots[k] contains every point with key >= keys_desc[k]

  Persistent2DCounter() {}

  // Rebuilds the structure from scratch over the domain [0, domain_size).
  // Every point must have 0 <= b, a < domain_size.
  void init(int domain_size, std::vector<Point> pts) {
    M = domain_size;
    inner_nodes.clear();
    outer_nodes.clear();
    keys_desc.clear();
    roots.clear();

    // One insertion clones one root-to-leaf path of the outer tree (levels
    // nodes) and, for each of those, one path of an inner tree (levels nodes).
    const std::size_t levels = tree_levels(M);
    inner_nodes.reserve(1 + pts.size() * levels * levels);
    outer_nodes.reserve(1 + pts.size() * levels);

    inner_nodes.push_back({});
    outer_nodes.push_back({});

    std::sort(pts.begin(), pts.end(), [&](const Point& x, const Point& y) {
      if (x.key != y.key) return x.key > y.key;
      if (x.b != y.b) return x.b < y.b;
      return x.a < y.a;
    });

    int cur_root = 0;
    for (int i = 0; i < (int)pts.size();) {
      int j = i;
      // Insert the whole group sharing this key, then snapshot the version.
      while (j < (int)pts.size() && pts[j].key == pts[i].key) {
        cur_root = outer_update(cur_root, 0, M, pts[j].b, pts[j].a);
        j++;
      }
      keys_desc.push_back(pts[i].key);
      roots.push_back(cur_root);
      i = j;
    }
  }

  // Number of points with key >= D, b in [bql, bqr) and a in [aql, aqr).
  int query(int D, int bql, int bqr, int aql, int aqr) const {
    int ver = version_for_D(D);
    if (ver < 0 || bql >= bqr || aql >= aqr) return 0;
    return outer_query(roots[ver], 0, M, bql, bqr, aql, aqr);
  }

private:
  // Nodes on the longest root-to-leaf path of a segment tree over [0, m):
  // ceil(log2(m)) + 1 (at least 1).
  static std::size_t tree_levels(int m) {
    const std::size_t size = static_cast<std::size_t>(std::max(m, 1));
    std::size_t levels = 1;
    while ((std::size_t{1} << (levels - 1)) < size) ++levels;
    return levels;
  }

  int clone_inner(int p) {
    inner_nodes.push_back(inner_nodes[p]);
    return (int)inner_nodes.size() - 1;
  }
  int clone_outer(int p) {
    outer_nodes.push_back(outer_nodes[p]);
    return (int)outer_nodes.size() - 1;
  }

  // Returns a new inner root equal to p with +1 at position pos.
  int inner_update(int p, int l, int r, int pos) {
    int u = clone_inner(p);
    inner_nodes[u].sum++;
    if (l + 1 == r) return u;
    int m = (l + r) >> 1;
    // The right-hand side runs (and may reallocate) before inner_nodes[u] is
    // indexed: assignment operands are sequenced right-to-left since C++17.
    if (pos < m) inner_nodes[u].l = inner_update(inner_nodes[u].l, l, m, pos);
    else         inner_nodes[u].r = inner_update(inner_nodes[u].r, m, r, pos);
    return u;
  }

  // Returns a new outer root equal to p with the point (bpos, apos) added.
  int outer_update(int p, int l, int r, int bpos, int apos) {
    int u = clone_outer(p);
    outer_nodes[u].inner = inner_update(outer_nodes[u].inner, 0, M, apos);
    if (l + 1 == r) return u;
    int m = (l + r) >> 1;
    if (bpos < m) outer_nodes[u].l = outer_update(outer_nodes[u].l, l, m, bpos, apos);
    else          outer_nodes[u].r = outer_update(outer_nodes[u].r, m, r, bpos, apos);
    return u;
  }

  int inner_query(int p, int l, int r, int ql, int qr) const {
    if (!p || qr <= l || r <= ql) return 0;
    if (ql <= l && r <= qr) return inner_nodes[p].sum;
    int m = (l + r) >> 1;
    return inner_query(inner_nodes[p].l, l, m, ql, qr) +
           inner_query(inner_nodes[p].r, m, r, ql, qr);
  }

  int outer_query(int p, int l, int r, int ql, int qr, int aql, int aqr) const {
    if (!p || qr <= l || r <= ql) return 0;
    if (ql <= l && r <= qr) {
      return inner_query(outer_nodes[p].inner, 0, M, aql, aqr);
    }
    int m = (l + r) >> 1;
    return outer_query(outer_nodes[p].l, l, m, ql, qr, aql, aqr) +
           outer_query(outer_nodes[p].r, m, r, ql, qr, aql, aqr);
  }

  // Index of the last version whose key is still >= D, or -1 if none.
  int version_for_D(int D) const {
    int lo = 0, hi = (int)keys_desc.size() - 1, ans = -1;
    while (lo <= hi) {
      int mid = (lo + hi) >> 1;
      if (keys_desc[mid] >= D) {
        ans = mid;
        lo = mid + 1;
      } else {
        hi = mid - 1;
      }
    }
    return ans;
  }
};

int n;
vi h, idx, first_left, first_right;
SegTree st;
bitset<100000> alive_fwd, alive_rev;

int delta_left[100000], delta_right[100000];

int rpos(int i) { return (int)alive_rev.size() - i - 1; }

Persistent2DCounter ds_mid, ds_left, ds_right;

void init(int N, vector<int> H) {
  n = N;
  h = H;
  idx.resize(n);
  iota(all(idx), 0);
  sort(all(idx), [&](int i, int j) { return h[i] < h[j]; });

  first_left.assign(n, -1);
  first_right.assign(n, n);
  st = SegTree(n, h);

  alive_fwd.reset();
  alive_rev.reset();

  for (int i : idx) {
    int prev = rpos((int)alive_rev._Find_next(rpos(i)));
    int nxt = (int)alive_fwd._Find_next(i);

    first_left[i] = prev;
    first_right[i] = min(n, nxt);

    delta_left[i] = (prev < 0 ? INF : st.query(prev, i + 1) - h[i]);
    delta_right[i] = (nxt >= n ? INF : st.query(i, nxt + 1) - h[i]);

    alive_fwd.set(i);
    alive_rev.set(rpos(i));
  }

  // Shift first_left by +1 so domain becomes [0..n].
  // Keep first_right as-is, so domain is also [0..n].
  int M = n + 1;

  vector<Persistent2DCounter::Point> pts_mid, pts_left, pts_right;
  pts_mid.reserve(n);
  pts_left.reserve(n);
  pts_right.reserve(n);

  for (int i = 0; i < n; i++) {
    int A = first_left[i] + 1; // in [0..n]
    int B = first_right[i];    // in [0..n]

    // first_left[i] >= L && first_right[i] <= R && min(delta_left, delta_right) >= D
    if (first_left[i] >= 0 && first_right[i] < n) {
      pts_mid.push_back({min(delta_left[i], delta_right[i]), B, A});
    }

    // first_left[i] < L && first_right[i] <= R && delta_right[i] >= D
    if (first_right[i] < n) {
      pts_left.push_back({delta_right[i], B, A});
    }

    // first_left[i] >= L && first_right[i] > R && delta_left[i] >= D
    if (first_left[i] >= 0) {
      pts_right.push_back({delta_left[i], B, A});
    }
  }

  ds_mid.init(M, pts_mid);
  ds_left.init(M, pts_left);
  ds_right.init(M, pts_right);
}

int max_towers(int L, int R, int D) {
  int t = 1;

  // first_left[i] >= L && first_right[i] <= R && min(delta_left[i], delta_right[i]) >= D
  // A = first_left+1 >= L+1, B = first_right <= R
  t += ds_mid.query(D, 0, R + 1, L + 1, n + 1);

  // first_left[i] < L && first_right[i] <= R && delta_right[i] >= D
  // A = first_left+1 <= L, B = first_right <= R
  t += ds_left.query(D, 0, R + 1, 0, L + 1);

  // first_left[i] >= L && first_right[i] > R && delta_left[i] >= D
  // A = first_left+1 >= L+1, B = first_right >= R+1
  t += ds_right.query(D, R + 1, n + 1, L + 1, n + 1);

  return t;
}
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...