Submission #1317217

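| # | Time | Username | Problem | Language | Result | Execution time | Memory |
|---|------|----------|---------|----------|--------|----------------|--------|
| 1317217 | | starplatinum | Holiday (IOI14_holiday) | C++20 | 100 / 100 | 298 ms | 42972 KiB |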
#include <algorithm>
#include <climits>
#include <cstdint>
#include <vector>

#ifdef LOCAL
#include <iostream>
#include <random>
#endif

using namespace std;

namespace {

struct PersistentSegTree {
    // Persistent (path-copying) segment tree over compressed value indices
    // [0, m-1]. Each node stores how many inserted elements fall into its
    // index range (cnt) and their total value (sum). Index 0 is the shared
    // null node: all fields zero and its children point back to 0, so root 0
    // represents the empty multiset.
    int m = 0;                 // number of distinct values
    int node_cnt = 0;          // index of the most recently allocated node
    vector<int> lch, rch;      // child node indices (0 = null)
    vector<int> cnt;           // element count per node
    vector<long long> sum;     // element sum per node
    vector<long long> values;  // original values by compressed index

    PersistentSegTree() = default;

    // Resets the tree so that only the null node exists.
    // sorted_unique_vals: distinct values in ascending order; compressed
    //   index i stands for sorted_unique_vals[i]. query_top_k relies on this
    //   ordering to treat higher indices as larger values.
    // n_elements: expected number of update() calls; used only to pre-size
    //   storage so that new_node() normally never has to grow it.
    void init(const vector<long long>& sorted_unique_vals, int n_elements) {
        values = sorted_unique_vals;
        m = (int)values.size();

        // lg = ceil(log2(m)); one update allocates at most lg + 1 nodes.
        int lg = 0;
        while ((1 << lg) < max(1, m)) ++lg;
        int cap = (n_elements + 5) * (lg + 3);
        cap = max(cap, 32);

        lch.assign(cap, 0);
        rch.assign(cap, 0);
        cnt.assign(cap, 0);
        sum.assign(cap, 0);
        node_cnt = 0; // node 0 is the null node
    }

    // Allocates a new node that copies node `from` and returns its index.
    // Storage grows geometrically. The floor of 32 keeps growth working when
    // storage is still empty (init() not called), where doubling would
    // yield a capacity of 0 and the writes below would be out of bounds.
    int new_node(int from) {
        int idx = ++node_cnt;
        if (idx >= (int)cnt.size()) {
            int new_cap = max(max((int)cnt.size() * 2, idx + 1), 32);
            lch.resize(new_cap);
            rch.resize(new_cap);
            cnt.resize(new_cap);
            sum.resize(new_cap);
        }
        lch[idx] = lch[from];
        rch[idx] = rch[from];
        cnt[idx] = cnt[from];
        sum[idx] = sum[from];
        return idx;
    }

    // Returns the root of a new version equal to version `prev` plus one
    // element with compressed index `pos` and value `val`, within the node
    // range [segL, segR]. Only the root-to-leaf path is copied, so `prev`
    // and every older version stay valid.
    int update(int prev, int segL, int segR, int pos, long long val) {
        int cur = new_node(prev);
        cnt[cur] += 1;
        sum[cur] += val;
        if (segL == segR) return cur;

        int mid = (segL + segR) >> 1;
        // Build the child first and store it afterwards. The recursive call
        // may reallocate lch/rch, so writing `lch[cur] = update(...)` directly
        // relies on C++17's assignment sequencing; before C++17 the reference
        // to lch[cur] could be taken first and then dangle.
        if (pos <= mid) {
            int child = update(lch[prev], segL, mid, pos, val);
            lch[cur] = child;
        } else {
            int child = update(rch[prev], mid + 1, segR, pos, val);
            rch[cur] = child;
        }
        return cur;
    }

    // Sum of the k largest elements of version(rightRoot) minus
    // version(leftRoot), restricted to compressed indices [segL, segR].
    // Assumes version(leftRoot) is contained in version(rightRoot) (e.g. two
    // prefix roots). If fewer than k elements exist, returns the sum of all
    // of them; returns 0 when k <= 0.
    long long query_top_k(int leftRoot, int rightRoot, int segL, int segR, int k) const {
        if (k <= 0) return 0LL;
        int cntDiff = cnt[rightRoot] - cnt[leftRoot];
        if (cntDiff <= 0) return 0LL;

        if (segL == segR) {
            // All elements in a leaf share one value.
            int take = min(k, cntDiff);
            return (long long)take * values[segL];
        }

        int mid = (segL + segR) >> 1;
        int lL = lch[leftRoot], lR = lch[rightRoot];
        int rL = rch[leftRoot], rR = rch[rightRoot];

        int cntRight = cnt[rR] - cnt[rL];
        long long sumRight = sum[rR] - sum[rL];

        // The right child holds the larger values: take from it first, and
        // descend left only for what it cannot supply.
        if (k <= cntRight) {
            return query_top_k(rL, rR, mid + 1, segR, k);
        }
        return sumRight + query_top_k(lL, lR, segL, mid, k - cntRight);
    }
};

struct OneDirectionSolver {
    // Handles trips that first walk left from s to i and then right to j
    // (i <= s <= j). Such a trip spends s + j - 2*i days moving; the rest go
    // to visiting the best attractions inside [i, j].
    int n = 0;
    int s = 0;
    int d = 0;
    vector<int> a;

    vector<long long> comp_vals; // sorted unique values
    vector<int> roots;           // prefix roots, size n+1
    PersistentSegTree pst;

    long long best = 0;

    // Largest right end j with s + j - 2*i <= d, clamped into [s, n-1].
    int jMax(int i) const {
        const long long reach = static_cast<long long>(d) - s + 2LL * i;
        return static_cast<int>(std::min<long long>(std::max<long long>(reach, s), n - 1));
    }

    // Value of the interval [i, j]: LLONG_MIN / 4 when the walk alone exceeds
    // d, otherwise the sum of the largest min(d - walk, j - i + 1) values.
    long long eval(int i, int j) const {
        const long long walk = static_cast<long long>(s) + j - 2LL * i;
        if (walk > d) return LLONG_MIN / 4;

        const int len = j - i + 1;
        const int k = std::min(static_cast<int>(d - walk), len);
        if (k <= 0) return 0LL;
        return pst.query_top_k(roots[i], roots[j + 1], 0, pst.m - 1, k);
    }

    // Divide and conquer over left ends i in [iL, iR], whose best right end
    // is searched in [jL, jR]. Uses an explicit work stack; the best right
    // end of the middle i splits the j-range for the two halves. Ties keep
    // the smallest j.
    void solve(int iL, int iR, int jL, int jR) {
        struct Range { int iLo, iHi, jLo, jHi; };
        vector<Range> pending;
        pending.push_back({iL, iR, jL, jR});

        while (!pending.empty()) {
            const Range r = pending.back();
            pending.pop_back();
            if (r.iLo > r.iHi) continue;

            const int mid = (r.iLo + r.iHi) >> 1;
            const int jLast = std::min(r.jHi, jMax(mid));

            int argBest = r.jLo;
            long long valBest = LLONG_MIN / 4;
            for (int j = r.jLo; j <= jLast; ++j) {
                const long long v = eval(mid, j);
                if (v > valBest) {
                    valBest = v;
                    argBest = j;
                }
            }
            best = std::max(best, valBest);

            // Push the right half first so the left half is handled next.
            pending.push_back({mid + 1, r.iHi, argBest, r.jHi});
            pending.push_back({r.iLo, mid - 1, r.jLo, argBest});
        }
    }

    // Builds the prefix versions for `arr` and returns the best total for
    // "left first, then right" trips from `start` within `days` days.
    long long run(const vector<int>& arr, int start, int days) {
        a = arr;
        n = static_cast<int>(a.size());
        s = start;
        d = days;
        best = 0;

        // Coordinate compression: the distinct values in ascending order.
        comp_vals.assign(a.begin(), a.end());
        std::sort(comp_vals.begin(), comp_vals.end());
        comp_vals.erase(std::unique(comp_vals.begin(), comp_vals.end()), comp_vals.end());

        pst.init(comp_vals, n);

        // roots[t] is the version that holds the multiset a[0..t-1].
        roots.assign(n + 1, 0);
        for (int t = 0; t < n; ++t) {
            const long long v = a[t];
            const auto where = std::lower_bound(comp_vals.begin(), comp_vals.end(), v);
            const int rank = static_cast<int>(where - comp_vals.begin());
            roots[t + 1] = pst.update(roots[t], 0, pst.m - 1, rank, v);
        }

        // Walking left and back to s costs 2*(s - i), so i >= s - d/2.
        const int iLo = std::max(0, s - d / 2);
        const int jHi = jMax(s);
        if (iLo > s || s > jHi) return 0LL;

        solve(iLo, s, s, jHi);
        return best;
    }
};

} // namespace

long long int findMaxAttraction(int n, int start, int d, int attraction[]) {
    // Any optimal trip either turns left first or turns right first. The
    // second case is the first case on the mirrored array, where city c
    // becomes n - 1 - c.
    vector<int> cities(n);
    std::copy_n(attraction, n, cities.begin());
    const vector<int> mirrored(cities.rbegin(), cities.rend());

    OneDirectionSolver solver;
    const long long leftFirst = solver.run(cities, start, d);
    const long long rightFirst = solver.run(mirrored, n - 1 - start, d);
    return std::max(leftFirst, rightFirst);
}

#ifdef LOCAL
static long long brute_solve(const vector<int>& a, int s, int d) {
    // Reference solver for small inputs. It tries every interval [lo, hi]
    // containing s and pays the cheaper walk (reach one end, then sweep to
    // the other). The remaining days go to the largest values in the interval.
    const int n = static_cast<int>(a.size());
    long long answer = 0;
    for (int lo = s; lo >= 0; --lo) {
        for (int hi = s; hi < n; ++hi) {
            const int walk = (hi - lo) + std::min(s - lo, hi - s);
            if (walk > d) continue;

            const int k = std::min(d - walk, hi - lo + 1);
            vector<int> window(a.begin() + lo, a.begin() + hi + 1);
            std::sort(window.rbegin(), window.rend()); // descending order
            long long total = 0;
            for (int t = 0; t < k; ++t) total += window[t];
            answer = std::max(answer, total);
        }
    }
    return answer;
}

int main() {
    // Randomized differential test: compares findMaxAttraction with the
    // brute-force reference on small inputs (n in [2, 10], values in
    // [0, 20], d in [0, 2n + n/2]). The fixed seed makes runs reproducible.
    std::mt19937 rng(1);
    for (int t = 0; t < 2000; ++t) {
        int n = std::uniform_int_distribution<int>(2, 10)(rng);
        vector<int> a(n);
        for (int i = 0; i < n; ++i) a[i] = std::uniform_int_distribution<int>(0, 20)(rng);
        int s = std::uniform_int_distribution<int>(0, n - 1)(rng);
        int d = std::uniform_int_distribution<int>(0, 2 * n + n / 2)(rng);

        long long fast = findMaxAttraction(n, s, d, a.data());
        long long brute = brute_solve(a, s, d);
        if (fast != brute) {
            cerr << "Mismatch!\n";
            cerr << "n=" << n << " s=" << s << " d=" << d << "\n";
            cerr << "a: ";
            for (int x : a) cerr << x << ' ';
            cerr << "\nfast=" << fast << " brute=" << brute << "\n";
            // Non-zero exit status so scripts/CI can detect the failure;
            // returning 0 here made a mismatch look like success.
            return 1;
        }
    }
    cerr << "All tests passed.\n";
    return 0;
}
#endif
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...