This submission was migrated from a previous version of oj.uz, which used a different machine for grading. It may get a different result if resubmitted.
#include "fish.h"
#include <stdio.h>
#include <vector>
#include <queue>
#include <algorithm>
#include <iostream>
#include <string>
#include <bitset>
#include <map>
#include <set>
#include <tuple>
#include <string.h>
#include <math.h>
#include <random>
#include <functional>
#include <assert.h>
#include <math.h>
#define all(x) (x).begin(), (x).end()
#define xx first
#define yy second
using namespace std;
// Priority-queue alias; with the default less<T> it is a max-heap
// (pass greater<T> for a min-heap). Not used in the visible code.
template<typename T, typename Pr = less<T>>
using pq = priority_queue<T, vector<T>, Pr>;
using i64 = long long int;
using ii = pair<int, int>;
using ii64 = pair<i64, i64>;
// Memo for dp(x, y, type); -1 means "not computed yet" (filled with memset in
// max_weights). 3005 * 3005 * 3 * 8 bytes is roughly 217 MB, so the general
// path only supports indices up to 3004.
i64 table[3005][3005][3];
// Memo for dp2(x, y, type) (used by sub3); dp2 never goes past y == 2.
i64 table2[100005][5][3];
// val[x + 1][y + 1] accumulates the weight of every fish at (x, y); index 0 in
// either dimension is never written and stays 0.
i64 val[3005][3005];
// Same layout as val, for the sub3 / dp2 path.
i64 val2[100005][5];
// psum[x][y] = val[x][1] + ... + val[x][y] (prefix sum over y for fixed x).
i64 psum[3005][3005];
// Same layout as psum, for the sub3 / dp2 path (only y <= 2 is filled).
i64 psum2[100005][5];
// Grid size, copied from max_weights' n_ so the helpers can read it.
int n;
// Memoised DP over (x, y, type), cached in table (-1 = not computed yet).
// Indices are 1-indexed as filled in max_weights (val[x + 1][y + 1]); x == 0 is
// the virtual starting column used by dp(0, 0, 1), and any x >= n + 1
// contributes 0.
//   type == 0: descending
//   type == 1: top of a descent (may switch to ascending here)
//   type == 2: ascending
// The state graph is acyclic: x never decreases, and at a fixed x type 2 only
// moves up while type 0 only moves down, so memoisation order is irrelevant.
// NOTE(review): the mapping from these path states to pier layouts is not
// spelled out in the original code; confirm before changing transitions.
i64 dp(int x, int y, int type)
{
    if (x >= n + 1)
        return 0;
    if (table[x][y][type] != -1)
        return table[x][y][type];
    auto& res = table[x][y][type];
    res = 0;
    if (type == 2)
    {
        // Ascending: one row up in this column, one row up in the next column,
        // or jump two / three columns ahead at the same row and continue from
        // a "top" (type 1) state there.
        if (y < n)
        {
            res = max(res, val[x][y] + dp(x, y + 1, type));
            if (x < n - 1)
                res = max(res, val[x][y] + dp(x + 1, y + 1, type));
        }
        res = max({ res, val[x][y] + dp(x + 2, y, 1), val[x][y] + dp(x + 3, y, 1) });
    }
    else if (type == 1)
    {
        // Collect every cell below this one in column x and switch to
        // ascending from here.
        if (x <= n - 1)
        {
            auto down = y == 0 ? 0 : psum[x][y - 1];
            res = dp(x, y, 2) + down;
        }
        // Otherwise just keep descending.
        res = max(res, dp(x, y, 0));
    }
    else
    {
        // Plain descent: one row down in this column, or one row down in the
        // next column.
        if (y == 0)
        {
            // Reached the bottom: simply move one column to the right.
            res = dp(x + 1, y, 1);
            return res;
        }
        // The candidate psum[x][y] + dp(x + 1, y + 1, 2) used to be stored in
        // res here and was then unconditionally overwritten by the assignment
        // below (a dead store that only cost an extra recursion), so it has
        // been removed. Re-enabling it with max() would not help either,
        // assuming non-negative weights: descending via dp(x, y - 1, 0) to row
        // 0 already collects psum[x][y] and reaches
        // dp(x + 1, 0, 1) >= dp(x + 1, 0, 2) >= dp(x + 1, y + 1, 2).
        res = val[x][y] + max({ dp(x + 1, y - 1, 1), dp(x, y - 1, 0), psum[x][y - 1] + dp(x + 2, y, 1), psum[x][y - 1] + dp(x + 3, y, 1) });
    }
    return res;
}
// Special case chosen by max_weights when every x[i] is even: the answer is
// the total weight of all fish.
// (Presumably every fish can be caught via piers in odd columns — confirm
// against the task statement.)
i64 sub1(vector<int> w)
{
    i64 total = 0;
    for_each(all(w), [&total](int weight) { total += weight; });
    return total;
}
// Special case chosen by max_weights when every x[i] is 0 or 1.
// res1 / res2 are the total weights with x == 0 / x == 1.
// For n <= 2 the answer is max(res1, res2). Otherwise, for every cut-off row r,
// the candidate is
//   (weight with x == 0 and y <= r) + (weight with x == 1 and y > r),
// and the best candidate (or max(res1, res2)) is returned.
// NOTE(review): presumably r + 1 is the pier height in column 1, with a pier
// in column 2 catching the rest of column 1 — confirm against the task.
i64 sub2(int m, const vector<int>& x, const vector<int>& y, const vector<int>& w)
{
    i64 res1 = 0, res2 = 0;
    for (int i = 0; i < m; i++)
    {
        if (x[i] == 0)
            res1 += w[i];
        else
            res2 += w[i];
    }
    if (n <= 2)
        return max(res1, res2);

    // Per-x prefix sums over y. These live on the heap and are sized by n; a
    // fixed-size local array (2 x 100005 i64) would put ~1.6 MB on the stack.
    vector<vector<i64>> sums(2, vector<i64>(n, 0));
    for (int i = 0; i < m; i++)
        sums[x[i]][y[i]] += w[i];
    for (int xi = 0; xi < 2; xi++)
        for (int yi = 1; yi < n; yi++)
            sums[xi][yi] += sums[xi][yi - 1];

    i64 res = max(res1, res2);
    // `row` (not `y`) so the loop index does not shadow the y parameter.
    for (int row = 0; row < n; row++)
        res = max(res, sums[0][row] + res2 - sums[1][row]);
    return res;
}
// Same memoised DP as dp(), but over val2 / psum2 / table2 with the y index
// capped at 2. Used by sub3, where every fish has y == 0 (stored at index 1),
// so x can range up to 100004 while y stays tiny.
//   type == 0: descending
//   type == 1: top of a descent (may switch to ascending here)
//   type == 2: ascending
// The state graph is acyclic (x never decreases; at a fixed x type 2 only
// moves up and type 0 only moves down), so memoisation order is irrelevant.
i64 dp2(int x, int y, int type)
{
    if (x >= n + 1)
        return 0;
    if (table2[x][y][type] != -1)
        return table2[x][y][type];
    auto& res = table2[x][y][type];
    res = 0;
    if (type == 2)
    {
        // Ascending: one row up in this column, one row up in the next column,
        // or jump two / three columns ahead at the same row and continue from
        // a "top" (type 1) state there.
        if (y < 2)
        {
            res = max(res, val2[x][y] + dp2(x, y + 1, type));
            if (x < n - 1)
                res = max(res, val2[x][y] + dp2(x + 1, y + 1, type));
        }
        res = max({ res, val2[x][y] + dp2(x + 2, y, 1), val2[x][y] + dp2(x + 3, y, 1) });
    }
    else if (type == 1)
    {
        // Collect every cell below this one in column x and switch to
        // ascending from here.
        if (x <= n - 1)
        {
            auto down = y == 0 ? 0 : psum2[x][y - 1];
            res = dp2(x, y, 2) + down;
        }
        // Otherwise just keep descending.
        res = max(res, dp2(x, y, 0));
    }
    else
    {
        // Plain descent: one row down in this column, or one row down in the
        // next column.
        if (y == 0)
        {
            // Reached the bottom: simply move one column to the right.
            res = dp2(x + 1, y, 1);
            return res;
        }
        // The candidate psum2[x][y] + dp2(x + 1, y + 1, 2) used to be stored
        // in res here and was then unconditionally overwritten by the
        // assignment below (a dead store that only cost an extra recursion),
        // so it has been removed. Assuming non-negative weights it is also
        // dominated by descending via dp2(x, y - 1, 0) to row 0 and then
        // ascending from dp2(x + 1, 0, 1).
        res = val2[x][y] + max({ dp2(x + 1, y - 1, 1), dp2(x, y - 1, 0), psum2[x][y - 1] + dp2(x + 2, y, 1), psum2[x][y - 1] + dp2(x + 3, y, 1) });
    }
    return res;
}
// Special case chosen by max_weights when every y[i] is 0.
// Fills the 1-indexed weight table val2 and its per-column prefix sums psum2
// (only up to index 2, the highest y that dp2 visits), resets the dp2 memo,
// and runs dp2 from the virtual start state (0, 0, 1).
i64 sub3(int m, vector<int> x, vector<int> y, vector<int> w)
{
    for (int i = 0; i < m; ++i)
    {
        const int col = x[i] + 1;
        const int row = y[i] + 1;
        val2[col][row] += w[i];
    }

    for (int col = 1; col <= n; ++col)
    {
        for (int row = 1; row <= 2; ++row)
        {
            psum2[col][row] += psum2[col][row - 1] + val2[col][row];
        }
    }

    memset(table2, -1, sizeof table2);
    return dp2(0, 0, 1);
}
// Entry point. Stores the grid size in the global n and dispatches, in order:
//   every x[i] even   -> sub1
//   every x[i] <= 1   -> sub2
//   every y[i] == 0   -> sub3
// Otherwise it fills the 1-indexed val / psum tables and runs the general dp
// from the virtual start state (0, 0, 1).
// NOTE(review): val / psum / table are 3005 wide, so the general path only
// stays in bounds for n <= 3004.
i64 max_weights(int n_, int m, vector<int> x, vector<int> y, vector<int> w)
{
    n = n_;

    auto isEven = [](int xi) { return xi % 2 == 0; };
    if (all_of(all(x), isEven))
        return sub1(w);

    auto atMostOne = [](int xi) { return xi <= 1; };
    if (all_of(all(x), atMostOne))
        return sub2(m, x, y, w);

    auto isZero = [](int yi) { return yi == 0; };
    if (all_of(all(y), isZero))
        return sub3(m, x, y, w);

    for (int i = 0; i < m; ++i)
        val[x[i] + 1][y[i] + 1] += w[i];

    // col / row instead of x / y so the loop indices don't shadow the params.
    for (int col = 1; col <= n; ++col)
        for (int row = 1; row <= n; ++row)
            psum[col][row] += psum[col][row - 1] + val[col][row];

    memset(table, -1, sizeof table);
    return dp(0, 0, 1);
}
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |