Submission #812489

# | Time | Username | Problem | Language | Result | Execution time | Memory
812489_martynasCatfish Farm (IOI22_fish)C++17
100 / 100
795 ms93652 KiB
#include "fish.h"

#include <bits/stdc++.h>

// Pulls every std name into file scope; unqualified names such as `pair`
// and `map` in the declarations below depend on it.
using namespace std;

// Shorthand: `v.pb(x)` expands to `v.push_back(x)`.
#define pb push_back

// 64-bit signed integer alias.
using ll = long long;
// Pair of ints; the key type of the map F declared below.
using pii = pair<int, int>;

// Large magnitude (1e16) intended as an "infinity" sentinel for ll values.
const ll inf = 1e16;

// File-scope integers; presumably the problem sizes N and M -- confirm
// against the code that reads them.
int n, m;
// File-scope map from an (int, int) pair to an ll value; presumably
// (column, row) -> fish weight -- confirm against the code that reads it.
// NOTE(review): as a global, its contents persist for the whole program run
// unless some user clears it explicitly.
map<pii, ll> F;

// ll max_weights_y0(int N, int M, vector<int> X, vector<int> Y, vector<int> W) {
//     n = N, m = M;
//     vector<ll> A(n);
//     for(int i = 0; i < m; i++) A[X[i]] = W[i];
//     auto get = [&](int i) {
//         return (i < 0 || i >= n ? 0 : A[i]);
//     };
//     ll mx1 = 0, mx2 = 0, mx3 = 0;
//     for(int i = 0; i < n; i++) {
//         ll nmx1, nmx2, nmx3;
//         nmx3 = max(mx3, mx2);
//         nmx2 = max(mx2, mx1);
//         nmx1 = max(mx1, mx3+get(i-1)+get(i+1));
//         nmx1 = max(nmx1, mx2+get(i+1));
//         nmx1 = max(nmx1, mx1-get(i)+get(i+1));
//         mx1 = nmx1, mx2 = nmx2, mx3 = nmx3;
//     }
//     return max({mx1, mx2, mx3});
// }

/// Entry point for the "Catfish Farm" task (IOI 2022, `fish.h`).
///
/// @param N  number of grid columns (columns are 0 .. N-1)
/// @param M  number of fish; only the first M entries of X, Y, W are read
/// @param X  X[f] is the column of fish f
/// @param Y  Y[f] is the row of fish f
/// @param W  W[f] is the weight of fish f
/// @return   the best total weight found by the column-by-column DP below
///           (never negative; 0 when there are no fish)
///
/// All working state is local to the call, so the function can be invoked
/// any number of times and each call only sees the fish passed to it. The
/// previous version kept the weights in a file-scope map that was never
/// cleared, and looked cells up with `operator[]`, which inserted a zero
/// entry for every probed empty cell.
///
/// Sketch of the DP (see the task statement for the catching rules):
///  - pos[c] holds the sorted, distinct rows of fish living in columns c-1
///    and c+1; these are the only candidate "heights" tried for column c.
///  - dp[0][c][j] / dp[1][c][j] is the best value with column c set to
///    candidate pos[c][j] while the heights are in an "increasing" (0) or
///    "decreasing" (1) run.
///  - mx_wo / mx_with carry the best value of any state at least two columns
///    back, without / with the fish of the column in between added.
long long max_weights(int N, int M, std::vector<int> X, std::vector<int> Y, std::vector<int> W) {
    using i64 = long long;
    using Cell = std::pair<int, int>;  // (column, row)

    // "State not reachable" marker. Sums are added on top of it, so it must
    // stay far away from both overflow and any real (non-negative) value.
    const i64 kNegInf = -10000000000000000LL;  // -1e16

    std::map<Cell, i64> weight_at;
    std::vector<std::vector<int>> pos(N);
    for (int f = 0; f < M; ++f) {
        if (X[f] > 0) pos[X[f] - 1].push_back(Y[f]);
        if (X[f] + 1 < N) pos[X[f] + 1].push_back(Y[f]);
        weight_at[Cell{X[f], Y[f]}] = W[f];
    }

    // Weight of the fish in cell (x, y), or 0 for an empty / out-of-grid
    // cell. Uses find() so that probing never inserts into the map.
    const auto fish = [&weight_at](int x, int y) -> i64 {
        const auto it = weight_at.find(Cell{x, y});
        return it == weight_at.end() ? i64{0} : it->second;
    };

    for (auto& rows : pos) {
        std::sort(rows.begin(), rows.end());
        rows.erase(std::unique(rows.begin(), rows.end()), rows.end());
    }

    std::vector<std::vector<i64>> dp[2];  // 0 = increasing run, 1 = decreasing run
    for (auto& phase : dp) {
        phase.resize(N);
        for (int c = 0; c < N; ++c) phase[c].assign(pos[c].size(), kNegInf);
    }

    i64 mx_with = 0;  // best so far, "with" the fish of the column to the right
    i64 mx_wo = 0;    // best so far, without them
    for (int c = 0; c < N; ++c) {
        const std::vector<int>& rows = pos[c];
        const int cnt = static_cast<int>(rows.size());

        // Fold the states of column c-2 into the running maxima.
        mx_wo = std::max(mx_wo, mx_with);
        if (c > 1) {
            const std::vector<int>& far = pos[c - 2];
            const int far_cnt = static_cast<int>(far.size());
            for (int j = 0; j < far_cnt; ++j) {
                mx_wo = std::max({mx_wo, dp[0][c - 2][j], dp[1][c - 2][j]});
            }
            i64 sum = 0;  // column c-1 fish at rows <= far[j]
            for (int j = 0; j < far_cnt; ++j) {
                sum += fish(c - 1, far[j]);
                mx_with = std::max({mx_with, dp[0][c - 2][j] + sum, dp[1][c - 2][j] + sum});
            }
        }

        // Start a new run at column c.
        {
            i64 sum = 0;  // column c-1 fish at rows <= rows[j]
            for (int j = 0; j < cnt; ++j) {
                sum += fish(c - 1, rows[j]);
                dp[0][c][j] = std::max({dp[0][c][j], mx_wo + sum, mx_with});
            }
        }

        if (c > 0) {
            const std::vector<int>& prev = pos[c - 1];
            const int prev_cnt = static_cast<int>(prev.size());

            // Increasing -> increasing: two-pointer sweep upwards. `best`
            // is the best lower candidate of column c-1 plus the column c-1
            // fish accumulated up to the current row.
            i64 best = kNegInf;
            for (int j = 0, k = 0; j < cnt; ++j) {
                while (k < prev_cnt && prev[k] < rows[j]) {
                    best = std::max(best, dp[0][c - 1][k]);
                    ++k;
                }
                best += fish(c - 1, rows[j]);
                dp[0][c][j] = std::max(dp[0][c][j], best);
            }

            // Increasing/decreasing -> decreasing: sweep downwards, adding
            // column c fish as the candidate row drops below them.
            best = kNegInf;
            for (int j = cnt - 1, k = prev_cnt - 1; j >= 0; --j) {
                // A fish exactly on row rows[j] is not counted for candidate
                // j itself, only for every lower candidate.
                i64 delayed = 0;
                while (k >= 0 && prev[k] >= rows[j]) {
                    best = std::max({best, dp[0][c - 1][k], dp[1][c - 1][k]});
                    if (prev[k] != rows[j]) {
                        best += fish(c, prev[k]);
                    } else {
                        delayed += fish(c, prev[k]);
                    }
                    --k;
                }
                dp[1][c][j] = std::max(dp[1][c][j], best);
                best += delayed;
            }
        }
    }

    // Close every state by adding the column c+1 fish at rows <= pos[c][j].
    i64 ans = 0;
    for (int c = 0; c < N; ++c) {
        i64 sum = 0;
        const int cnt = static_cast<int>(pos[c].size());
        for (int j = 0; j < cnt; ++j) {
            sum += fish(c + 1, pos[c][j]);
            ans = std::max({ans, dp[0][c][j] + sum, dp[1][c][j] + sum});
        }
    }
    return ans;
}

// int main() {
//     int N, M;
//     vector<int> X, Y, W;
//     cin >> N >> M;
//     for(int i = 0; i < M; i++) {
//         int x, y, w; cin >> x >> y >> w;
//         X.pb(x), Y.pb(y), W.pb(w);
//     }
//     ll ans1 = max_weights(N, M, X, Y, W);
//     ll ans2 = max_weights_y0(N, M, X, Y, W);
//     cerr << ans1 << " " << ans2 << "\n";
//     assert(ans1 == ans2);
// }
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...