#include "fish.h"
#include <bits/stdc++.h>
using namespace std;
using ll = long long;

// Computes the maximum total weight obtainable by a column-by-column DP.
//
// Inputs: N columns (indices 0..N-1), M items; item f sits in column X[f] at
// row Y[f] and has weight W[f].
// NOTE(review): the signature and "fish.h" suggest this is the pier-building
// grid task (a value h chosen for a column covers rows 0..h, h == -1 meaning
// "nothing built") — confirm against the task statement. The comments below
// use that vocabulary.
//
// For every column only a small set of candidate heights is examined: -1 and
// the rows of the items located in the two neighbouring columns. Two DP arrays
// are kept per column, indexed by candidate height:
//   * rising[j]  — best value when the items of the current column have not
//                  been credited yet;
//   * falling[j] — best value when the items of the current column lying above
//                  its own height and up to the previous column's height have
//                  already been credited. It is always >= rising[j] for j >= 1
//                  by construction.
// Returns the largest value in the final `falling` array (never below 0).
ll max_weights(int N, int M, vector<int> X, vector<int> Y, vector<int> W) {
    // Group item indices by column.
    vector<vector<int>> byColumn(N);
    for (int f = 0; f < M; ++f) {
        byColumn[X[f]].push_back(f);
    }

    // Order each column by row and build prefix sums of weights in that order.
    vector<vector<ll>> prefix(N);
    for (int c = 0; c < N; ++c) {
        vector<int> &items = byColumn[c];
        sort(items.begin(), items.end(), [&](int a, int b) { return Y[a] < Y[b]; });
        prefix[c].assign(items.size() + 1, 0);
        for (size_t k = 0; k < items.size(); ++k) {
            prefix[c][k + 1] = prefix[c][k] + W[items[k]];
        }
    }

    // Total weight of the items of column c whose row lies in [lo, hi).
    // Columns -1 and >= N are treated as empty.
    const auto rowBelow = [&](int item, int row) { return Y[item] < row; };
    const auto weightInRows = [&](int c, int lo, int hi) -> ll {
        if (c != -1 && c < N) {
            const vector<int> &items = byColumn[c];
            const int first = static_cast<int>(
                lower_bound(items.begin(), items.end(), lo, rowBelow) - items.begin());
            const int last = static_cast<int>(
                lower_bound(items.begin(), items.end(), hi, rowBelow) - items.begin());
            return prefix[c][last] - prefix[c][first];
        }
        return 0;
    };

    // State for the (virtual) column before column 0: a single height of -1.
    vector<int> prevHeights{-1};
    vector<ll> rising{0};
    vector<ll> falling{0};

    for (int col = 0; col < N; ++col) {
        // Candidate heights for this column (duplicates are kept as-is).
        vector<int> heights{-1};
        if (col > 0) {
            for (int f : byColumn[col - 1]) {
                heights.push_back(Y[f]);
            }
        }
        if (col + 1 < N) {
            for (int f : byColumn[col + 1]) {
                heights.push_back(Y[f]);
            }
        }
        sort(heights.begin(), heights.end());

        const int prevCount = static_cast<int>(prevHeights.size());
        const int curCount = static_cast<int>(heights.size());

        // risingPrefix[k]: best over the first k previous heights of
        // rising + weight of the previous column's items above that height.
        vector<ll> risingPrefix(prevCount + 1, 0);
        for (int p = 0; p < prevCount; ++p) {
            const ll candidate = rising[p] + weightInRows(col - 1, prevHeights[p] + 1, N);
            risingPrefix[p + 1] = max(risingPrefix[p], candidate);
        }

        // fallingSuffix[k]: best over previous heights with index >= k of
        // falling + weight of this column's items at or below that height.
        vector<ll> fallingSuffix(prevCount + 1, 0);
        for (int p = prevCount; p-- > 0;) {
            const ll candidate = falling[p] + weightInRows(col, 0, prevHeights[p] + 1);
            fallingSuffix[p] = max(fallingSuffix[p + 1], candidate);
        }

        vector<ll> nextRising(curCount, 0);
        vector<ll> nextFalling(curCount, 0);

        // Height -1 (index 0, since every row is >= 0): any previous state may
        // precede it.
        nextRising[0] = max(0LL, *max_element(falling.begin(), falling.end()));
        nextFalling[0] = fallingSuffix[0];

        for (int j = 1; j < curCount; ++j) {
            const int top = heights[j];
            // Number of previous heights that are <= top.
            const int k = static_cast<int>(
                upper_bound(prevHeights.begin(), prevHeights.end(), top) - prevHeights.begin());

            // Previous height <= top: credit previous-column items in
            // rows (previous height, top].
            const ll viaRising = risingPrefix[k] - weightInRows(col - 1, top + 1, N);
            // Previous height > top: credit this column's items in
            // rows (top, previous height].
            const ll viaFalling = fallingSuffix[k] - weightInRows(col, 0, top + 1);

            // falling[0] is the previous column at height -1, the only
            // "falling" state from which the heights may start growing again.
            nextRising[j] = max(falling[0], viaRising);
            nextFalling[j] = max(nextRising[j], viaFalling);
        }

        rising = std::move(nextRising);
        falling = std::move(nextFalling);
        prevHeights = std::move(heights);
    }

    return max(0LL, *max_element(falling.begin(), falling.end()));
}