// Submission #1262572
//
// # | Time | Username | Problem | Language | Result | Execution time | Memory
// 1262572 | | kunzaZa183 | Holiday (IOI14_holiday) | C++20 | 100 / 100 | 1166 ms | 18856 KiB
#include <bits/stdc++.h>
using namespace std;
using ll = long long;

/// Builds the "one-way" table: best total for every budget when the walk
/// along `cur` is never undone.
///
/// For a budget b and a prefix length p, reaching index p - 1 is charged
/// p - 1 units, and the remaining b - (p - 1) units are spent picking the
/// largest values among cur[0..p-1] (nothing is picked when that count is not
/// positive; everything is picked when it exceeds p). dp[b] is the best such
/// total over the prefix lengths that are examined.
///
/// The prefix lengths are searched with a level-by-level divide and conquer
/// over the budgets: the best prefix length found for the middle budget of a
/// range bounds the search of both halves. This assumes the best prefix
/// length never decreases as the budget grows.
/// (In IOI 2014 "Holiday" terms a unit is presumably one day.)
///
/// @param cur values in walking order; cur[0] is the starting position.
/// @return table of size 2 * cur.size() + 2 indexed by budget; all zeros when
///         `cur` is empty.
vector<ll> calc(vector<int> cur) {
  const int n = static_cast<int>(cur.size());
  const int maxd = 2 * n + 1;

  vector<ll> dp(maxd + 1);
  if (n == 0) {
    // No positions at all: every budget collects nothing. Returning early also
    // avoids walking a zero-sized segment tree.
    return dp;
  }

  // Values sorted in decreasing order (ties: larger index first), each paired
  // with its original index.
  vector<pair<int, int>> helper;
  helper.reserve(n);
  for (int i = 0; i < n; i++) {
    helper.emplace_back(cur[i], i);
  }
  sort(helper.rbegin(), helper.rend());

  // loc[i] = rank of cur[i] in the sorted order, i.e. its segment-tree leaf.
  vector<int> loc(n);
  for (int i = 0; i < n; i++) {
    loc[helper[i].second] = i;
  }

  // to[b] = prefix length at which dp[b] reached its final value.
  vector<int> to(maxd + 1);

  // Segment tree over ranks (root 0, children 2k+1 / 2k+2): per node the
  // number of active leaves and the sum of their values.
  vector<int> ct(4 * n);
  vector<ll> sum(4 * n);

  // Marks the leaf `pos` active. Every leaf is activated at most once between
  // two resets, so adding along the root-to-leaf path is enough.
  auto activate = [&](int pos) {
    const ll value = helper[pos].first;
    int node = 0, lo = 0, hi = n - 1;
    while (true) {
      ++ct[node];
      sum[node] += value;
      if (lo == hi) {
        break;
      }
      const int m = (lo + hi) / 2;
      if (pos <= m) {
        node = 2 * node + 1;
        hi = m;
      } else {
        node = 2 * node + 2;
        lo = m + 1;
      }
    }
  };

  // Sum of the `want` largest active values (all of them if fewer are active,
  // zero if want <= 0).
  auto top_sum = [&](int want) {
    ll total = 0;
    int node = 0, lo = 0, hi = n - 1;
    while (lo != hi) {
      const int left = 2 * node + 1;
      const int m = (lo + hi) / 2;
      if (ct[left] >= want) {
        node = left;
        hi = m;
      } else {
        // The whole left half is taken; continue with the remainder on the right.
        total += sum[left];
        want -= ct[left];
        node = left + 1;
        lo = m + 1;
      }
    }
    if (want >= 1 && ct[node] >= 1) {
      total += sum[node];
    }
    return total;
  };

  // Tries prefix length `prefix` for `budget`; only a strict improvement
  // replaces the stored answer.
  auto relax = [&](int budget, int prefix) {
    const ll total = top_sum(budget - (prefix - 1));
    if (total > dp[budget]) {
      dp[budget] = total;
      to[budget] = prefix;
    }
  };

  // Each entry is {first budget, last budget, lowest prefix, highest prefix}.
  vector<array<int, 4>> vpii;
  vpii.push_back({0, maxd, 0, n});
  while (!vpii.empty()) {
    // Every level sweeps the prefix pointer from zero over an empty tree.
    fill(ct.begin(), ct.end(), 0);
    fill(sum.begin(), sum.end(), ll{0});

    int in = 0;  // number of leading elements currently in the tree
    for (const auto& [l, r, mini, maxi] : vpii) {
      const int mid = (l + r) / 2;
      for (; in < maxi; ++in) {
        relax(mid, in);
        activate(loc[in]);
      }
      relax(mid, in);
    }

    // Split every range around its middle budget, narrowing the prefix window
    // with the optimum just found.
    vector<array<int, 4>> vpii2;
    for (const auto& [l, r, mini, maxi] : vpii) {
      const int mid = (l + r) / 2;
      if (l <= mid - 1) {
        vpii2.push_back({l, mid - 1, mini, to[mid]});
      }
      if (mid + 1 <= r) {
        vpii2.push_back({mid + 1, r, to[mid], maxi});
      }
    }
    vpii.swap(vpii2);
  }

  return dp;
}

/// Builds the "round-trip" table: best total for every budget when the walk
/// along `cur` is charged twice.
///
/// For a budget b and a prefix length p, reaching index p - 1 is charged
/// 2 * (p - 1) units, and the remaining b - 2 * (p - 1) units are spent
/// picking the largest values among cur[0..p-1] (nothing is picked when that
/// count is not positive; everything is picked when it exceeds p). dp[b] is
/// the best such total over the prefix lengths that are examined.
///
/// The prefix lengths are searched with a level-by-level divide and conquer
/// over the budgets: the best prefix length found for the middle budget of a
/// range bounds the search of both halves. This assumes the best prefix
/// length never decreases as the budget grows.
/// (In IOI 2014 "Holiday" terms this is presumably "walk out and come back",
/// one day per step.)
///
/// @param cur values in walking order; cur[0] is the starting position.
/// @return table of size 3 * cur.size() + 2 indexed by budget; all zeros when
///         `cur` is empty.
vector<ll> calc2(vector<int> cur) {
  const int n = static_cast<int>(cur.size());
  const int maxd = 3 * n + 1;

  vector<ll> dp(maxd + 1);
  if (n == 0) {
    // No positions at all: every budget collects nothing. Returning early also
    // avoids walking a zero-sized segment tree.
    return dp;
  }

  // Values sorted in decreasing order (ties: larger index first), each paired
  // with its original index.
  vector<pair<int, int>> helper;
  helper.reserve(n);
  for (int i = 0; i < n; i++) {
    helper.emplace_back(cur[i], i);
  }
  sort(helper.rbegin(), helper.rend());

  // loc[i] = rank of cur[i] in the sorted order, i.e. its segment-tree leaf.
  vector<int> loc(n);
  for (int i = 0; i < n; i++) {
    loc[helper[i].second] = i;
  }

  // to[b] = prefix length at which dp[b] reached its final value.
  vector<int> to(maxd + 1);

  // Segment tree over ranks (root 0, children 2k+1 / 2k+2): per node the
  // number of active leaves and the sum of their values.
  vector<int> ct(4 * n);
  vector<ll> sum(4 * n);

  // Marks the leaf `pos` active. Every leaf is activated at most once between
  // two resets, so adding along the root-to-leaf path is enough.
  auto activate = [&](int pos) {
    const ll value = helper[pos].first;
    int node = 0, lo = 0, hi = n - 1;
    while (true) {
      ++ct[node];
      sum[node] += value;
      if (lo == hi) {
        break;
      }
      const int m = (lo + hi) / 2;
      if (pos <= m) {
        node = 2 * node + 1;
        hi = m;
      } else {
        node = 2 * node + 2;
        lo = m + 1;
      }
    }
  };

  // Sum of the `want` largest active values (all of them if fewer are active,
  // zero if want <= 0).
  auto top_sum = [&](int want) {
    ll total = 0;
    int node = 0, lo = 0, hi = n - 1;
    while (lo != hi) {
      const int left = 2 * node + 1;
      const int m = (lo + hi) / 2;
      if (ct[left] >= want) {
        node = left;
        hi = m;
      } else {
        // The whole left half is taken; continue with the remainder on the right.
        total += sum[left];
        want -= ct[left];
        node = left + 1;
        lo = m + 1;
      }
    }
    if (want >= 1 && ct[node] >= 1) {
      total += sum[node];
    }
    return total;
  };

  // Tries prefix length `prefix` for `budget`; the walk there is charged in
  // both directions. Only a strict improvement replaces the stored answer.
  auto relax = [&](int budget, int prefix) {
    const ll total = top_sum(budget - 2 * (prefix - 1));
    if (total > dp[budget]) {
      dp[budget] = total;
      to[budget] = prefix;
    }
  };

  // Each entry is {first budget, last budget, lowest prefix, highest prefix}.
  vector<array<int, 4>> vpii;
  vpii.push_back({0, maxd, 0, n});
  while (!vpii.empty()) {
    // Every level sweeps the prefix pointer from zero over an empty tree.
    fill(ct.begin(), ct.end(), 0);
    fill(sum.begin(), sum.end(), ll{0});

    int in = 0;  // number of leading elements currently in the tree
    for (const auto& [l, r, mini, maxi] : vpii) {
      const int mid = (l + r) / 2;
      for (; in < maxi; ++in) {
        relax(mid, in);
        activate(loc[in]);
      }
      relax(mid, in);
    }

    // Split every range around its middle budget, narrowing the prefix window
    // with the optimum just found.
    vector<array<int, 4>> vpii2;
    for (const auto& [l, r, mini, maxi] : vpii) {
      const int mid = (l + r) / 2;
      if (l <= mid - 1) {
        vpii2.push_back({l, mid - 1, mini, to[mid]});
      }
      if (mid + 1 <= r) {
        vpii2.push_back({mid + 1, r, to[mid], maxi});
      }
    }
    vpii.swap(vpii2);
  }

  return dp;
}

long long findMaxAttraction(int n, int start, int d, int attraction[]) {
  vector<int> right;
  for (int i = start; i < n; i++) {
    right.push_back(attraction[i]);
  }
  vector<int> left;
  left.push_back(0);
  for (int i = start - 1; i >= 0; i--) {
    left.push_back(attraction[i]);
  }

  vector<ll> gor = calc(right), backr = calc2(right), gol = calc(left),
             backl = calc2(left);

  ll ans = 0;
  for (int i = 0; i < min(d + 1, (int)gor.size()); i++) {
    ans = max(ans, gor[i] + backl[min(d - i, (int)backl.size() - 1)]);
  }
  for (int i = 0; i < min(d + 1, (int)gol.size()); i++) {
    ans = max(ans, gol[i] + backr[min(d - i, (int)backr.size() - 1)]);
  }

  return ans;
}
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...