#include "fish.h"
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define pii pair<int, int>
#define ff first
#define ss second
#define pb push_back
long long max_weights(int32_t n, int32_t m, vector<int32_t> X, vector<int32_t> Y,
vector<int32_t> W) {
for (int i = 0; i < m; i++) {
X[i]++; Y[i]++;
}
int olda[n + 1][n + 1];
for (int i = 0; i <= n; i++) {
for (int j = 0; j <= n; j++) {
olda[i][j] = 0;
}
}
for (int i = 0; i < m; i++) {
olda[X[i]][Y[i]] = W[i];
}
int oldps[n + 1][n + 1]; // ps[i][j] = a[i][1] + ... + a[i][j]
for (int i = 1; i <= n; i++) {
oldps[i][0] = 0;
for (int j = 1; j <= n; j++) {
oldps[i][j] = oldps[i][j - 1] + olda[i][j];
}
}
vector<pii> a[n + 1], ps[n + 1];
set<pii> critpts;
for (int i = 1; i <= n; i++) {
a[i].pb({0, 0});
ps[i].pb({0, 0});
critpts.insert({i, 0});
}
for (int i = 0; i < m; i++) {
a[X[i]].pb({Y[i], W[i]});
// critical points: (X[i] +- 1, Y[i]), (X[i], Y[i] - 1)
if (X[i] > 1) {
critpts.insert({X[i] - 1, Y[i]});
}
if (X[i] < n) {
critpts.insert({X[i] + 1, Y[i]});
}
critpts.insert({X[i], Y[i] - 1});
}
for (int i = 1; i <= n; i++) {
sort(a[i].begin(), a[i].end());
for (pii &x : a[i]) {
ps[i].pb({x.ff, ps[i].back().ss + x.ss});
}
}
function<int(int, int)> GetPS = [&](int x, int y) -> int {
if (ps[x].empty() || y < ps[x].front().ff) return 0;
int lb = 0, rb = (int)(ps[x].size()) - 1;
while (lb < rb) {
int mid = (lb + rb + 1) >> 1;
if (ps[x][mid].ff <= y) lb = mid;
else rb = mid - 1;
}
return ps[x][lb].ss;
};
int olddp[2][n + 1][n + 1];
// olddp[increasing][pos][height] = max ans from 1 to pos with pos getting height, height must be increasing
for (int i = 0; i <= n; i++) {
olddp[0][0][i] = olddp[1][0][i] = (i == 0 ? 0 : -1e18);
olddp[0][1][i] = olddp[1][1][i] = 0;
}
int oldcur[n + 1];
function<void(void)> resetoldcur = [&]() {
for (int i = 0; i <= n; i++) oldcur[i] = 0;
};
vector<pii> dp[2][n + 1];
for (int i = 0; i <= n; i++) {
dp[0][0].pb({i, (i == 0 ? 0 : -1e18)});
dp[1][0].pb({i, (i == 0 ? 0 : -1e18)});
dp[0][1].pb({i, 0});
dp[1][1].pb({i, 0});
}
for (pii x : critpts) {
if (x.ff < 2) continue;
dp[0][x.ff].pb({x.ss, 0});
dp[1][x.ff].pb({x.ss, 0});
}
vector<pii> cur;
for (int i = 2; i <= n; i++) {
for (int j = 0; j <= n; j++) {
olddp[0][i][j] = olddp[1][i][j] = 0;
}
sort(dp[0][i].begin(), dp[0][i].end());
sort(dp[1][i].begin(), dp[1][i].end());
// Case 0: height of i is 0
resetoldcur();
olddp[0][i][0] = max(olddp[0][i - 1][0], olddp[1][i - 1][0]);
for (int j = 1; j <= n; j++) {
olddp[0][i][0] = max(olddp[0][i][0], max(olddp[0][i - 1][j], olddp[1][i - 1][j]) + oldps[i][j]);
}
olddp[1][i][0] = olddp[0][i][0];
// dp[flag][i][0].ff is always 0
for (int j = 0; j < (int)(dp[0][i - 1].size()); i++) {
dp[0][i][0].ss = max(dp[0][i][0].ss, dp[0][i - 1][j].ss + GetPS(i - 1, dp[0][i - 1][j].ff));
}
for (int j = 0; j < (int)(dp[1][i - 1].size()); i++) {
dp[0][i][0].ss = max(dp[0][i][0].ss, dp[1][i - 1][j].ss + GetPS(i - 1, dp[1][i - 1][j].ff));
}
dp[1][i][0].ss = dp[0][i][0].ss;
// Case 1: height of i-1 is 0
// Case 1.1: height of i-2 <= height of i
resetoldcur();
oldcur[0] = olddp[0][i - 2][0];
for (int j = 1; j <= n; j++) {
oldcur[j] = max(oldcur[j - 1], max(olddp[0][i - 2][j], olddp[1][i - 2][j]));
}
for (int j = 0; j <= n; j++) {
olddp[0][i][j] = max(olddp[0][i][j], oldcur[j] + oldps[i - 1][j]);
olddp[1][i][j] = max(olddp[1][i][j], oldcur[j] + oldps[i - 1][j]);
}
// Case 1.2: height of i-2 >= height of i
resetoldcur();
oldcur[n] = olddp[0][i - 2][n] + oldps[i - 1][n];
for (int j = n - 1; j >= 0; j--) {
oldcur[j] = max(oldcur[j + 1], max(olddp[0][i - 2][j], olddp[1][i - 2][j]) + oldps[i - 1][j]);
}
for (int j = 0; j <= n; j++) {
olddp[0][i][j] = max(olddp[0][i][j], oldcur[j]);
olddp[1][i][j] = max(olddp[1][i][j], oldcur[j]);
}
// now height of i-1 > 0
// Case 2: height of i-1 <= height of i
resetoldcur();
oldcur[1] = olddp[0][i - 1][1];
for (int j = 2; j <= n; j++) {
oldcur[j] = max(oldcur[j - 1] + olda[i - 1][j], olddp[0][i - 1][j]);
}
for (int j = 1; j <= n; j++) {
olddp[0][i][j] = max(olddp[0][i][j], oldcur[j]);
}
// Case 3: height of i-1 >= height of i
resetoldcur();
oldcur[n] = max(olddp[0][i - 1][n], olddp[1][i - 1][n]);
for (int j = n - 1; j >= 1; j--) {
oldcur[j] = max(oldcur[j + 1] + olda[i][j + 1], max(olddp[0][i - 1][j], olddp[1][i - 1][j]));
}
for (int j = 1; j <= n; j++) {
olddp[1][i][j] = max(olddp[1][i][j], oldcur[j]);
}
}
int ans = max(dp[0][n][0].ss, dp[1][n][0].ss);
for (int i = 1; i <= n; i++) {
// cerr << olddp[0][n][i] << ' ' << olddp[1][n][i] << endl;
ans = max(ans, max(olddp[0][n][i], olddp[1][n][i]));
}
return ans;
}
#undef int
# |
Verdict |
Execution time |
Memory |
Grader output |
1 |
Runtime error |
963 ms |
2097152 KB |
Execution killed with signal 9 |
2 |
Halted |
0 ms |
0 KB |
- |
# |
Verdict |
Execution time |
Memory |
Grader output |
1 |
Runtime error |
1 ms |
340 KB |
Execution killed with signal 11 |
2 |
Halted |
0 ms |
0 KB |
- |
# |
Verdict |
Execution time |
Memory |
Grader output |
1 |
Runtime error |
759 ms |
2097152 KB |
Execution killed with signal 9 |
2 |
Halted |
0 ms |
0 KB |
- |
# |
Verdict |
Execution time |
Memory |
Grader output |
1 |
Runtime error |
1 ms |
340 KB |
Execution killed with signal 11 |
2 |
Halted |
0 ms |
0 KB |
- |
# |
Verdict |
Execution time |
Memory |
Grader output |
1 |
Runtime error |
1 ms |
340 KB |
Execution killed with signal 11 |
2 |
Halted |
0 ms |
0 KB |
- |
# |
Verdict |
Execution time |
Memory |
Grader output |
1 |
Runtime error |
1 ms |
340 KB |
Execution killed with signal 11 |
2 |
Halted |
0 ms |
0 KB |
- |
# |
Verdict |
Execution time |
Memory |
Grader output |
1 |
Runtime error |
759 ms |
2097152 KB |
Execution killed with signal 9 |
2 |
Halted |
0 ms |
0 KB |
- |
# |
Verdict |
Execution time |
Memory |
Grader output |
1 |
Runtime error |
963 ms |
2097152 KB |
Execution killed with signal 9 |
2 |
Halted |
0 ms |
0 KB |
- |