Submission #1362904

#TimeUsernameProblemLanguageResultExecution timeMemory
1362904ereringCatfish Farm (IOI22_fish)C++20
0 / 100
1095 ms10388 KiB
#include <bits/stdc++.h>
using namespace std;

long long max_weights(int N, int M, vector<int> X, vector<int> Y, vector<int> W) {
    const long long NEG = -(1LL << 60);

    vector<vector<pair<int, int>>> fish(N);
    for (int i = 0; i < M; ++i) {
        if (0 <= X[i] && X[i] < N && 0 <= Y[i] && Y[i] < N) {
            fish[Y[i]].push_back({X[i], W[i]});
        }
    }

    for (int c = 0; c < N; ++c) {
        sort(fish[c].begin(), fish[c].end());
    }

    auto calc = [&](int col, int l, int m, int r) -> long long {
        if (col < 0 || col >= N) return 0;
        int mx = max(l, r);
        if (mx <= m) return 0;

        long long res = 0;
        for (auto [row, wt] : fish[col]) {
            if (m <= row && row < mx) res += wt;
        }
        return res;
    };

    if (N <= 300) {
        vector<long long> dp(N + 1, 0), ndp(N + 1, NEG);

        for (int h1 = 0; h1 <= N; ++h1) {
            dp[h1] = calc(0, 0, 0, h1);
        }

        for (int c = 1; c < N; ++c) {
            fill(ndp.begin(), ndp.end(), NEG);
            for (int prev = 0; prev <= N; ++prev) {
                if (dp[prev] <= NEG / 2) continue;
                for (int cur = 0; cur <= N; ++cur) {
                    long long add = calc(c, prev, cur, 0);
                    ndp[cur] = max(ndp[cur], dp[prev] + add);
                }
            }
            dp.swap(ndp);
        }

        return *max_element(dp.begin(), dp.end());
    }

    vector<int> vals;
    vals.reserve(M + 2);
    vals.push_back(0);
    vals.push_back(N);
    for (int x : X) {
        vals.push_back(x);
        if (x + 1 <= N) vals.push_back(x + 1);
    }
    sort(vals.begin(), vals.end());
    vals.erase(unique(vals.begin(), vals.end()), vals.end());

    int K = (int)vals.size();

    vector<long long> dp(K, NEG), ndp(K, NEG);

    for (int j = 0; j < K; ++j) {
        dp[j] = calc(0, 0, 0, vals[j]);
    }

    for (int c = 1; c < N; ++c) {
        fill(ndp.begin(), ndp.end(), NEG);

        for (int pj = 0; pj < K; ++pj) {
            if (dp[pj] <= NEG / 2) continue;
            int prev = vals[pj];

            for (int cj = 0; cj < K; ++cj) {
                int cur = vals[cj];
                long long add = calc(c, prev, cur, 0);
                ndp[cj] = max(ndp[cj], dp[pj] + add);
            }
        }

        dp.swap(ndp);
    }

    return *max_element(dp.begin(), dp.end());
}
#Result Execution timeMemoryGrader output
Fetching results...
#Result Execution timeMemoryGrader output
Fetching results...
#Result Execution timeMemoryGrader output
Fetching results...
#Result Execution timeMemoryGrader output
Fetching results...
#Result Execution timeMemoryGrader output
Fetching results...
#Result Execution timeMemoryGrader output
Fetching results...
#Result Execution timeMemoryGrader output
Fetching results...
#Result Execution timeMemoryGrader output
Fetching results...