Submission #1196264

# | Time | Username | Problem | Language | Result | Execution time | Memory
1196264 | avighna | Holiday (IOI14_holiday) | C++20
100 / 100
3832 ms | 16020 KiB
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <functional>
#include <vector>

/// Iterative (bottom-up) segment tree over a monoid (f, idt).
///
/// Leaves live at seg[n .. 2n-1]; internal node i holds f(seg[2i], seg[2i+1]).
/// seg[0] is never written by set() and therefore stays equal to idt.
///
/// f must be associative and idt must be its identity. partition_point()
/// additionally requires n to be a power of two, so that every internal node
/// covers a contiguous, aligned range of leaves.
template <typename T> class SegmentTree {
public:
  int n;
  T idt;
  std::vector<T> seg;
  std::function<T(T, T)> f;

  /// @param n   number of leaves
  /// @param f   associative combine function
  /// @param idt identity element of f
  // Initializers are listed in declaration order (n, idt, seg, f), which is
  // the order the compiler actually initializes them in.
  SegmentTree(int n, std::function<T(T, T)> f = std::plus<T>(), T idt = T())
      : n(n), idt(idt), seg(2 * n, idt), f(f) {}

  /// Sets leaf `idx` (0-based) to x and recomputes all of its ancestors.
  void set(int idx, T x) {
    idx += n;
    seg[idx] = x;
    for (idx /= 2; idx > 0; idx /= 2) {
      seg[idx] = f(seg[2 * idx], seg[2 * idx + 1]);
    }
  }

  /// Returns f over leaves [l, r] (inclusive), combined left to right.
  /// Returns f(idt, idt) when l > r.
  T query(int l, int r) {
    T ansL = idt, ansR = idt;
    for (l += n, r += n + 1; l < r; l /= 2, r /= 2) {
      if (l & 1) ansL = f(ansL, seg[l++]);
      if (r & 1) ansR = f(seg[--r], ansR);
    }
    return f(ansL, ansR);
  }

  /// Binary search over prefixes that start at leaf `l`.
  ///
  /// Returns the smallest i in [l, n) such that t(f(leaf[l], ..., leaf[i])) is
  /// false, or n if t holds for every such prefix (and also when l >= n).
  /// t must be monotone: once false for a prefix, it stays false for every
  /// longer prefix. Requires n to be a power of two.
  int partition_point(int l, const std::function<bool(T)> &t) {
    if (l >= n) {
      return n;
    }
    T p = idt;  // f over the leaves already accepted, starting at the original l
    l += n;
    do {
      // While l is a left child, its parent starts at the same leaf and
      // covers a longer range, so try the parent instead.
      while (l % 2 == 0) l >>= 1;
      if (!t(f(p, seg[l]))) {
        // The first failing leaf is inside node l: descend, accepting each
        // left child that still satisfies t.
        while (l < n) {
          l <<= 1;
          if (t(f(p, seg[l]))) p = f(p, seg[l++]);
        }
        return l - n;
      }
      p = f(p, seg[l++]);
      // When l becomes a power of two, the accepted range has reached the
      // last leaf. Stop here instead of wrapping to the leftmost node.
    } while ((l & -l) != l);
    return n;
  }
};


/// Best total attraction for walks that start at `start`, first sweep right,
/// and then (optionally) turn back to sweep the cities left of `start`.
///
/// Walks that go left first are not considered here; the caller covers them
/// by running this routine again on the mirrored array.
///
/// Cost model used below: moving between adjacent cities costs one day, and
/// visiting a city (collecting its attraction) costs one day.
///
/// @param n          number of cities
/// @param start      index of the starting city
/// @param d          number of days available
/// @param attraction attraction of each city; taken by value because it is
///                   rewritten locally to build the left-hand sub-problem
/// @return maximum total attraction for this walk shape
long long int solve(int n, int start, int d, std::vector<int> attraction) {
  // With no days, nothing can be visited. This guard is also needed for
  // termination: calc_f below is driven over the day range [1, d]. When
  // d == 0 that range is empty, the tl == tr base case is never reached, and
  // the recursion would never end.
  if (d <= 0) {
    return 0;
  }

  // Rank cities by attraction, largest first (ties: larger index first).
  // a[i] = {attraction[i], rank of city i}. The rank is the city's leaf in the
  // segment tree, so leaves are ordered from largest to smallest attraction.
  auto build_a = [&]() {
    std::vector<std::pair<int, int>> by_value;
    by_value.reserve(n);
    for (int i = 0; i < n; ++i) {
      by_value.push_back({attraction[i], i});
    }
    std::sort(by_value.rbegin(), by_value.rend());
    std::vector<std::pair<int, int>> ranked(n);
    for (int rank = 0; rank < n; ++rank) {
      ranked[by_value[rank].second] = {by_value[rank].first, rank};
    }
    return ranked;
  };
  std::vector<std::pair<int, int>> a = build_a();

  // Each leaf holds either nothing or one active city (count 1, sum = its
  // attraction). The sum of the first k occupied leaves is therefore the sum
  // of the k largest active attractions.
  struct Node {
    int count = 0;
    int64_t sum = 0;
  };
  auto round_up_pow2 = [](int x) {
    if (x == 0) return 1;
    --x;
    x |= x >> 1;
    x |= x >> 2;
    x |= x >> 4;
    x |= x >> 8;
    x |= x >> 16;
    return x + 1;
  };
  // partition_point needs a power-of-two leaf count.
  SegmentTree<Node> st(round_up_pow2(n), [](Node lhs, Node rhs) {
    return Node{lhs.count + rhs.count, lhs.sum + rhs.sum};
  });

  // For a walk starting at `origin` and heading right with x days, let
  // f[x] = {turning city, best attraction sum}. Reaching city i costs
  // i - origin days, so the remaining x - (i - origin) days buy that many
  // of the largest attractions in [origin, i].
  //
  // The divide and conquer relies on the turning city being non-decreasing
  // in x. So f is computed for the middle day count tm over candidate cities
  // [l, r], and the halves then recurse over [l, opt] and [opt, r].
  //
  // Precondition: the tree holds exactly the cities [origin, l - 1]. On
  // return it additionally holds some subset of [l, r]; callers clear it.
  auto calc_f = [&](auto &&self, int tl, int tr, int l, int r, int origin,
                    std::vector<std::pair<int, int64_t>> &f) {
    int tm = (tl + tr) / 2;
    // {sum, -city}: ties prefer the leftmost turning city.
    std::pair<int64_t, int> best = {-1, -1};
    for (int i = l; i <= r; ++i) {
      st.set(a[i].second, {1, a[i].first});
      const int budget = tm - (i - origin);
      // Last leaf whose prefix still has at most `budget` active cities.
      int idx = st.partition_point(0, [budget](Node x) {
        return x.count <= budget;
      }) - 1;
      if (idx == -1) {
        continue;
      }
      best = std::max(best, {st.query(0, idx).sum, -i});
    }
    f[tm] = {-best.second, best.first};
    if (tl == tr) {
      return;
    }
    // Restore the precondition for the right half: keep [origin, opt - 1].
    for (int i = f[tm].first; i <= r; ++i) {
      st.set(a[i].second, Node{});
    }
    self(self, tm + 1, tr, f[tm].first, r, origin, f);
    // Restore the precondition for the left half: keep [origin, l - 1].
    for (int i = l; i <= r; ++i) {
      st.set(a[i].second, Node{});
    }
    self(self, tl, tm, l, f[tm].first, origin, f);
  };

  // Right sweep from `start`.
  std::vector<std::pair<int, int64_t>> f1(d + 1);
  calc_f(calc_f, 1, d, start, n - 1, start, f1);
  for (int i = 0; i < n; ++i) {
    st.set(i, Node{});
  }

  // Left sweep starting at start - 1. Reverse the prefix so the sweep becomes
  // a rightward walk from index 0, and zero the cities that the right sweep
  // already covers.
  std::reverse(attraction.begin(), attraction.begin() + start);
  for (int i = start; i < n; ++i) {
    attraction[i] = 0;
  }
  a = build_a();
  std::vector<std::pair<int, int64_t>> f2(d + 1);
  calc_f(calc_f, 1, d, 0, n - 1, 0, f2);

  // Give i days to the right sweep. Walking back from the turning city to
  // start - 1 costs f1[i].first - start + 1 days, and whatever is left goes
  // to the left sweep.
  long long int ans = 0;
  for (int i = 1; i <= d; ++i) {
    long long int cur = f1[i].second;
    int rem = d - i - f1[i].first + start - 1;
    if (rem > 0) {
      cur += f2[rem].second;
    }
    ans = std::max(ans, cur);
  }
  return ans;
}

/// Grader entry point: the larger of the two walk shapes.
///
/// solve() only handles walks that head right first. Running it again on the
/// mirrored city list, with the mirrored start index, handles walks that head
/// left first.
long long int findMaxAttraction(int n, int start, int d, int attraction[]) {
  std::vector<int> values(attraction, attraction + n);
  const long long int right_first = solve(n, start, d, values);

  std::reverse(values.begin(), values.end());
  const int mirrored_start = n - 1 - start;
  const long long int left_first = solve(n, mirrored_start, d, values);

  return std::max(right_first, left_first);
}
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...