Submission #915340

# | Time | Username | Problem | Language | Result | Execution time | Memory
915340 | vjudge1 | Shortcut (IOI16_shortcut) | C++17
97 / 100
2013 ms | 100264 KiB
#include "shortcut.h"

#include <bits/stdc++.h>
using namespace std;

// IOI 2016 "Shortcut".
//
// Setup:
// - Main stations 0..n-1 lie on a line; l[i] is the distance from station i
//   to station i+1.
// - Station i has a secondary station hanging off it at distance d[i].
// - One express line of length c is built between two main stations p < q.
// - Returns the smallest achievable diameter (largest shortest-path distance
//   between any two stations).
//
// Approach: binary search on the answer `target`. With x[i] the position of
// station i, a pair i < j is "far" when its main-line route is too long:
//     (x[j] + d[j]) - (x[i] - d[i]) > target.
// A far pair must then use the express line:
//     |x[i] - x[p]| + |x[j] - x[q]| + d[i] + d[j] + c <= target.
// The crossed route i -> q, p -> j is never shorter, because for sorted
// endpoints |x[i]-x[p]| + |x[j]-x[q]| <= |x[i]-x[q]| + |x[j]-x[p]|.
// Expanding the absolute values, with a = x + d, b = x - d, k = c - target:
//     x[p] + x[q] >= a[i] + a[j] + k
//     x[p] + x[q] <= b[i] + b[j] - k
//     x[q] - x[p] <= b[j] - a[i] - k
//     x[q] - x[p] >= a[j] - b[i] + k
// For each j, only the largest a[i] and the smallest b[i] among its far
// partners matter, so all four bounds are collected in one sweep. A
// two-pointer scan over p then looks for a q > p that satisfies all of them.
//
// Cost: O(n log n) for the two sorts, plus O(n) per feasibility check.
long long find_shortcut(int n, std::vector<int> l, std::vector<int> d, int c) {
        using i64 = std::int64_t;
        constexpr i64 kInf = 1'000'000'000'000'000'000LL;

        // Out of contract (the grader guarantees n >= 2). No express line can
        // be built; the only nonzero distance is to station 0's secondary.
        if (n < 2) return n == 1 ? d[0] : 0;

        // x[i] = position of main station i; a = x + d, b = x - d.
        std::vector<i64> x(n, 0), a(n), b(n);
        for (int i = 1; i < n; i++) x[i] = x[i - 1] + l[i - 1];
        for (int i = 0; i < n; i++) {
                a[i] = x[i] + d[i];
                b[i] = x[i] - d[i];
        }

        // Station indices sorted by a and by b. Neither order depends on the
        // target, so both are computed once, outside the binary search.
        std::vector<int> ord_a(n), ord_b(n);
        std::iota(ord_a.begin(), ord_a.end(), 0);
        std::iota(ord_b.begin(), ord_b.end(), 0);
        std::sort(ord_a.begin(), ord_a.end(), [&](int i, int j) { return a[i] < a[j]; });
        std::sort(ord_b.begin(), ord_b.end(), [&](int i, int j) { return b[i] < b[j]; });

        // Returns true iff some express line (p, q) makes the diameter <= target.
        auto feasible = [&](i64 target) -> bool {
                const i64 k = c - target;
                i64 low_sum = -kInf, high_sum = kInf;
                i64 low_diff = -kInf, high_diff = kInf;

                // Indices of the two largest a[] and the two smallest b[] among
                // activated stations (-1 = none). Two are kept so that station
                // j itself can be excluded when it is its own "partner".
                int max_a1 = -1, max_a2 = -1;
                int min_b1 = -1, min_b2 = -1;
                auto activate = [&](int s) {
                        if (max_a1 == -1 || a[s] >= a[max_a1]) {
                                max_a2 = max_a1;
                                max_a1 = s;
                        } else if (max_a2 == -1 || a[s] > a[max_a2]) {
                                max_a2 = s;
                        }
                        if (min_b1 == -1 || b[s] <= b[min_b1]) {
                                min_b2 = min_b1;
                                min_b1 = s;
                        } else if (min_b2 == -1 || b[s] < b[min_b2]) {
                                min_b2 = s;
                        }
                };
                // Best tracked value belonging to a station other than j,
                // or `none` if there is no such station.
                auto best_excluding = [](const std::vector<i64>& v, int first, int second, int j, i64 none) {
                        const int idx = (first == j) ? second : first;
                        return idx == -1 ? none : v[idx];
                };

                // Sweep j by increasing a[j]. The partners i with
                // a[j] - b[i] > target form a growing prefix of ord_b.
                // A partner with x[i] > x[j] implies d[i] + d[j] > target.
                // Its bounds then force low_sum > high_sum. That verdict is
                // correct, since the pair's distance is at least d[i] + d[j].
                int next = 0;
                for (int j : ord_a) {
                        while (next < n && a[j] - b[ord_b[next]] > target) activate(ord_b[next++]);
                        const i64 best_a = best_excluding(a, max_a1, max_a2, j, -kInf);
                        const i64 best_b = best_excluding(b, min_b1, min_b2, j, kInf);
                        low_sum = std::max(low_sum, a[j] + best_a + k);
                        high_sum = std::min(high_sum, b[j] + best_b - k);
                        low_diff = std::max(low_diff, a[j] - best_b + k);
                        high_diff = std::min(high_diff, b[j] - best_a - k);
                }
                if (low_sum > high_sum || low_diff > high_diff) return false;

                // For each p, the q > p that satisfy the sum bounds form
                // (sum_lo, sum_hi]. Those that satisfy the difference bounds
                // form [diff_lo, diff_hi). Because x is non-decreasing, all
                // four boundaries move monotonically as p grows.
                int sum_lo = n - 1, sum_hi = n - 1, diff_lo = 0, diff_hi = 0;
                for (int p = 0; p < n; p++) {
                        diff_lo = std::max(diff_lo, p + 1);
                        while (sum_lo > p && x[p] + x[sum_lo] >= low_sum) sum_lo--;
                        while (sum_hi >= 0 && x[p] + x[sum_hi] > high_sum) sum_hi--;
                        while (diff_lo < n && x[diff_lo] - x[p] < low_diff) diff_lo++;
                        while (diff_hi < n && x[diff_hi] - x[p] <= high_diff) diff_hi++;
                        sum_lo = std::max(sum_lo, p);
                        const int q_first = std::max(sum_lo + 1, diff_lo);
                        const int q_last = std::min(sum_hi, diff_hi - 1);
                        if (q_first <= q_last) return true;
                }
                return false;
        };

        // Without any express line the diameter is at most x[n-1] + 2*max(d).
        // Adding an express line never increases it, so feasible(hi) holds:
        // no pair is "far" at that target.
        const i64 max_d = *std::max_element(d.begin(), d.end());
        i64 lo = 0, hi = x[n - 1] + 2 * max_d;
        while (lo < hi) {
                const i64 mid = lo + (hi - lo) / 2;
                if (feasible(mid)) {
                        hi = mid;
                } else {
                        lo = mid + 1;
                }
        }
        return hi;
}

Compilation message (stderr)

shortcut.cpp: In function 'long long int find_shortcut(int, std::vector<int>, std::vector<int>, int)':
shortcut.cpp:121:35: warning: suggest parentheses around '+' inside '>>' [-Wparentheses]
  121 |                 int64_t mid = low + high >> 1;
      |                               ~~~~^~~~~~
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...