// Submission #1296651
//
// # | Time | Username | Problem | Language | Result | Execution time | Memory
// 1296651 | - | hdqmgay | IOI Fever (JOI21_fever) | C++20 | 0 / 100 | 1 ms | 568 KiB
#include <iostream>
#include <vector>
#include <algorithm>
#include <climits>

using namespace std;

typedef long long ll;

// Number of elements; read from standard input in main().
int n;
// Input values, 0-indexed (positions 0..n-1).
vector<ll> x;
// Prefix sums of the costs: main() leaves s[0] = 0 and
// s[i] = c[0] + ... + c[i-1], so s has n + 1 entries.
vector<ll> s;
// Sparse table filled by build_rmq(): rmq[i][j] is the index of a minimum of x
// over the 2^j consecutive positions starting at i.
vector<vector<int>> rmq;
// log_tbl[len] = floor(log2(len)) for 1 <= len <= n; filled by build_rmq().
vector<int> log_tbl;

// Build sparse table for O(1) RMQ (to find index of min X)
void build_rmq() {
    log_tbl.resize(n + 1);
    log_tbl[1] = 0;
    for (int i = 2; i <= n; i++) {
        log_tbl[i] = log_tbl[i / 2] + 1;
    }

    int k = log_tbl[n] + 1;
    rmq.resize(n, vector<int>(k));

    for (int i = 0; i < n; i++) {
        rmq[i][0] = i;
    }

    for (int j = 1; j < k; j++) {
        for (int i = 0; i + (1 << j) <= n; i++) {
            if (x[rmq[i][j - 1]] < x[rmq[i + (1 << (j - 1))][j - 1]]) {
                rmq[i][j] = rmq[i][j - 1];
            } else {
                rmq[i][j] = rmq[i + (1 << (j - 1))][j - 1];
            }
        }
    }
}

// Query the index of the minimum element in [L, R]
// Returns the index of a minimum of x over the inclusive range [l, r] by
// comparing the two (possibly overlapping) power-of-two windows that cover it.
// When the two candidates hold equal values, the one from the right window is
// returned. Assumes build_rmq() has run and 0 <= l <= r < n.
int query_min_index(int l, int r) {
    const int level = log_tbl[r - l + 1];
    const int from_left = rmq[l][level];
    const int from_right = rmq[r - (1 << level) + 1][level];
    return (x[from_left] < x[from_right]) ? from_left : from_right;
}

// Solves the "merge" step in O(N log N)
// Best value over all ranges [a, b] with l <= a <= m <= b <= r of
//     max(x[a..b]) - x[m] - (c[a] + ... + c[b]).
// The cost sum is taken from the prefix sums as s[b + 1] - s[a], where
// s[0] = 0 and s[k] is the sum of the first k costs (the layout main() builds).
//
// Parameters: l, r - inclusive bounds of the segment; m - the position every
// considered range must contain (solve() passes the index of the segment
// minimum). Assumes 0 <= l <= m <= r < n and s.size() >= r + 2.
//
// Returns the best value, or -LLONG_MAX if no pair was found.
// Runs in time linear in r - l + 1.
ll solve_crossing(int l, int m, int r) {
    const ll NEG_INF = -LLONG_MAX;

    // Left candidates (A_a, C_a) = (max(x[a..m]), s[a]) for a = m down to l.
    // Position a pairs with s[a]; s[a - 1] would be one cost too far left and
    // would read s[-1] for a == 0.
    vector<pair<ll, ll>> left_vals;
    left_vals.reserve(m - l + 1);
    ll current_max = NEG_INF;
    for (int i = m; i >= l; i--) {
        current_max = max(current_max, x[i]);
        left_vals.push_back({current_max, s[i]});
    }

    // Right candidates (B_b, D_b) = (max(x[m..b]), -s[b + 1]) for b = m up to r.
    vector<pair<ll, ll>> right_vals;
    right_vals.reserve(r - m + 1);
    current_max = NEG_INF;
    for (int i = m; i <= r; i++) {
        current_max = max(current_max, x[i]);
        right_vals.push_back({current_max, -s[i + 1]});
    }

    // Both lists hold running maxima, so they are already non-decreasing in
    // their first component; the two-pointer sweeps below only need that
    // order, hence no sort is required.

    ll max_ans = NEG_INF;

    // Case A_a >= B_b: the range maximum is A_a, so pair each left entry with
    // the largest D_b among right entries whose maximum does not exceed it.
    ll max_d = NEG_INF;
    size_t right_pos = 0;
    for (size_t left_pos = 0; left_pos < left_vals.size(); left_pos++) {
        while (right_pos < right_vals.size() &&
               right_vals[right_pos].first <= left_vals[left_pos].first) {
            max_d = max(max_d, right_vals[right_pos].second);
            right_pos++;
        }
        if (max_d != NEG_INF) {
            max_ans = max(max_ans,
                          left_vals[left_pos].first + left_vals[left_pos].second + max_d);
        }
    }

    // Case A_a < B_b: the range maximum is B_b, so pair each right entry with
    // the largest C_a among left entries whose maximum is strictly smaller.
    ll max_c = NEG_INF;
    size_t left_pos = 0;
    for (size_t right_idx = 0; right_idx < right_vals.size(); right_idx++) {
        while (left_pos < left_vals.size() &&
               left_vals[left_pos].first < right_vals[right_idx].first) {
            max_c = max(max_c, left_vals[left_pos].second);
            left_pos++;
        }
        if (max_c != NEG_INF) {
            max_ans = max(max_ans,
                          right_vals[right_idx].first + right_vals[right_idx].second + max_c);
        }
    }

    // max_ans currently holds max(A, B) + C + D; subtract x[m] to finish.
    if (max_ans == NEG_INF) {
        return NEG_INF; // No valid range found.
    }
    return max_ans - x[m];
}

// The main recursive divide and conquer function
// Divide and conquer over the inclusive segment [l, r]: split at the index
// returned by query_min_index(), evaluate the ranges that contain that index
// with solve_crossing(), then recurse on the parts to its left and right.
// Returns the largest of the three results, or -LLONG_MAX for an empty
// segment (l > r).
ll solve(int l, int r) {
    if (r < l) {
        return -LLONG_MAX; // Empty segment contributes nothing.
    }

    const int pivot = query_min_index(l, r);

    // Ranges containing the pivot first, then the two sides.
    ll best = solve_crossing(l, pivot, r);
    best = max(best, solve(l, pivot - 1));
    best = max(best, solve(pivot + 1, r));
    return best;
}

int main() {
    cin >> n;

    // Use 1-based indexing for prefix sums, 0-based for x
    x.resize(n);
    s.resize(n + 1);
    s[0] = 0;

    for (int i = 0; i < n; i++) {
        cin >> x[i];
    }
    for (int i = 0; i < n; i++) {
        ll c;
        cin >> c;
        s[i + 1] = s[i] + c;
    }

    // Adjust x and s to be 0-based for the algorithm
    // (My implementation uses 0-based for x, 1-based for s)
    // Let's adjust to be consistent. 0-based for x, 0-based for s (s[-1]=0)
    // The code above assumes x is 0-based (0..n-1) and s is 1-based (s[0]=0, s[1]=c[0]...)
    // This is a bit confusing. Let's fix it to be 0-indexed for x and 0-indexed for s (s[i] = sum c[0..i])
    // The code above is correct if s[i] = sum of first i elements (c[0]..c[i-1])
    // So s[l-1] is s[i] and s[r] is s[j+1] for range [i,j]
    // Let's rewrite the main to be clearer.
    
    // Rerun main with 0-based indexing for x and 0-based prefix sums
    s.clear();
    s.resize(n);
    ll c;
    cin >> c;
    s[0] = c;
    for (int i = 1; i < n; i++) {
        cin >> c;
        s[i] = s[i - 1] + c;
    }
    
    // Helper function to get prefix sum of c[l..r]
    // S[i] = c[0] + ... + c[i]
    // sum(l,r) = s[r] - (l==0 ? 0 : s[l-1])
    // The code above uses s[l-1] and s[r].
    // s[i] in the code should be s[i] from this main (sum c[0..i])
    // s[i-1] in `solve_crossing` should be `(i==0 ? 0 : s[i-1])`
    // -s[i] in `solve_crossing` should be `-s[i]`
    // This is all fine, but we need to handle s[-1].
    // Let's re-do the S vector to be 1-indexed (size n+1, s[0]=0)
    
    s.clear();
    s.resize(n + 1);
    s[0] = 0;
    for(int i=0; i<n; i++) {
        cin >> c;
        s[i+1] = s[i] + c;
    }
    // Now x is 0-indexed (0..n-1)
    // And s is 1-indexed (0..n) where s[i] = sum of c[0..i-1]
    // So sum(l,r) = s[r+1] - s[l]
    // The code `solve_crossing` needs to be updated.
    
    // Let's restart main() and stick to the logic in the code.
    // The code `solve_crossing` uses `s[i-1]` and `s[i]`.
    // It assumes s[0..n] where s[0]=0, s[i] = c[0] + ... + c[i-1]
    // So, `s[i-1]` becomes `s[i]` and `-s[r]` becomes `-s[r+1]`
    
    // --- Clean Main ---
    cin >> n;
    x.resize(n);
    s.resize(n + 1);
    
    for (int i = 0; i < n; i++) cin >> x[i];
    
    s[0] = 0;
    for (int i = 0; i < n; i++) {
        ll c_val;
        cin >> c_val;
        s[i + 1] = s[i] + c_val;
    }
    
    // s[i] now stores sum of first i costs (c[0]..c[i-1])
    // `solve_crossing` uses `s[i-1]` (cost up to i-2) and `s[i]` (cost up to i-1)
    // This is for 0-indexed x.
    // For x[i], cost is c[i]. Prefix sum is s[i+1] = c[0]..c[i]
    // For range [l, r], sum is s[r+1] - s[l]
    // `left_vals` needs `s[i]` (for index i)
    // `right_vals` needs `-s[i+1]` (for index i)
    
    // Let's modify `solve_crossing` slightly to use the 1-indexed s
    // (A_l, C_l) = (max(x[l..m]), s[l])
    // (B_r, D_r) = (max(x[m..r]), -s[r+1])
    // This is what the code implicitly does, assuming s[0] = 0.
    // It's correct as written.
    
    build_rmq();
    cout << solve(0, n - 1) << endl;

    return 0;
}

/*
Example main() if you're confused by the indices:

int main() {
    cin >> n;
    x.resize(n);
    s.resize(n + 1); // s[0] = 0, s[i] = c[0] + ... + c[i-1]

    for (int i = 0; i < n; i++) {
        cin >> x[i];
    }
    s[0] = 0;
    for (int i = 0; i < n; i++) {
        ll c_val;
        cin >> c_val;
        s[i + 1] = s[i] + c_val;
    }
    
    // Index convention: x is 0-indexed and s is 1-indexed (s[0] = 0), so for
    // the range x[l..r] the cost is c[l] + ... + c[r] = s[r + 1] - s[l].
    // solve_crossing therefore has to pair position i with:
    //   left_vals.push_back({current_max, s[i]});        // C_l = s[l]
    //   right_vals.push_back({current_max, -s[i + 1]});  // D_r = -s[r + 1]
    // Using s[i - 1] and -s[i] instead would shift every cost window one
    // position to the left and read s[-1] when i == 0.

    // Sanity check on the single-element range [m, m]:
    // the left entry is {x[m], s[m]}, the right entry is {x[m], -s[m + 1]},
    // and the cost is c[m] = s[m + 1] - s[m], so
    // value = (x[m] + s[m] - s[m + 1]) - x[m] = -c[m].
    
    build_rmq();
    cout << solve(0, n - 1) << endl;

    return 0;
}
*/
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...