제출 #628059

#제출 시각아이디문제언어결과실행 시간메모리
628059jwvg0425메기 농장 (IOI22_fish)C++17
70 / 100
810 ms361640 KiB
#include "fish.h"
#include <stdio.h>
#include <vector>
#include <queue>
#include <algorithm>
#include <iostream>
#include <string>
#include <bitset>
#include <map>
#include <set>
#include <tuple>
#include <string.h>
#include <math.h>
#include <random>
#include <functional>
#include <assert.h>
#include <math.h>
#define all(x) (x).begin(), (x).end()
#define xx first
#define yy second

using namespace std;

template<typename T, typename Pr = less<T>>
using pq = priority_queue<T, vector<T>, Pr>;
using i64 = long long int;
using ii = pair<int, int>;
using ii64 = pair<i64, i64>;

i64 table[3005][3005][3];
i64 table2[100005][5][3];
i64 val[3005][3005];
i64 val2[100005][5];
i64 psum[3005][3005];
i64 psum2[100005][5];
int n;

// type == 0 desc, type == 1 desc top, type == 2 asc
i64 dp(int x, int y, int type)
{
	if (x >= n + 1)
		return 0;

	if (table[x][y][type] != -1)
		return table[x][y][type];

	auto& res = table[x][y][type];
	res = 0;

	if (type == 2)
	{
		// 한 칸 더 위 or 오른쪽 위 or 두칸 뛰면서 desc
		if (y < n)
		{
			res = max(res, val[x][y] + dp(x, y + 1, type));
			if (x < n - 1)
				res = max(res, val[x][y] + dp(x + 1, y + 1, type));
		}

		res = max({ res, val[x][y] + dp(x + 2, y, 1), val[x][y] + dp(x + 3, y, 1) });
	}
	else if (type == 1)
	{
		// 내 아래 다 먹고 여기부터 asc로 전환 가능
		if (x <= n - 1)
		{
			auto down = y == 0 ? 0 : psum[x][y - 1];
			res = dp(x, y, 2) + down;
		}

		// 그 외의 경우, 그냥 단순 내림차
		res = max(res, dp(x, y, 0));
	}
	else
	{
		// 단순 내림차 -> 한칸 밑으로 or 오른쪽 한칸 밑으로
		if (y == 0)
		{
			// 맨밑 왔으면 그냥 한칸 옆으로 가면 됨
			res = dp(x + 1, y, 1);
			return res;
		}

		if (x < n - 1 && y < n)
			res = max(res, psum[x][y] + dp(x + 1, y + 1, 2));

		res = val[x][y] + max({ dp(x + 1, y - 1, 1), dp(x, y - 1, 0), psum[x][y - 1] + dp(x + 2, y, 1), psum[x][y-1] + dp(x + 3, y, 1) });
	}

	return res;
}

i64 sub1(vector<int> w)
{
	i64 res = 0;
	for (auto& wi : w)
		res += wi;

	return res;
}

i64 sub2(int m, vector<int> x, vector<int> y, vector<int> w)
{

	i64 res1 = 0, res2 = 0;

	for (int i = 0; i < m; i++)
	{
		if (x[i] == 0)
			res1 += w[i];
		else
			res2 += w[i];
	}

	if (n <= 2)
		return max(res1, res2);

	i64 sums[2][100005] = { 0, };

	for (int i = 0; i < m; i++)
		sums[x[i]][y[i]] += w[i];
	
	for (int xi = 0; xi < 2; xi++)
		for (int yi = 1; yi < n; yi++)
			sums[xi][yi] += sums[xi][yi - 1];

	i64 res = max(res1, res2);

	for (int y = 0; y < n; y++)
		res = max(res, sums[0][y] + res2 - sums[1][y]);

	return res;
}

i64 dp2(int x, int y, int type)
{
	if (x >= n + 1)
		return 0;

	if (table2[x][y][type] != -1)
		return table2[x][y][type];

	auto& res = table2[x][y][type];
	res = 0;

	if (type == 2)
	{
		// 한 칸 더 위 or 오른쪽 위 or 두칸 뛰면서 desc
		if (y < 2)
		{
			res = max(res, val2[x][y] + dp2(x, y + 1, type));
			if (x < n - 1)
				res = max(res, val2[x][y] + dp2(x + 1, y + 1, type));
		}

		res = max({ res, val2[x][y] + dp2(x + 2, y, 1), val2[x][y] + dp2(x + 3, y, 1) });
	}
	else if (type == 1)
	{
		// 내 아래 다 먹고 여기부터 asc로 전환 가능
		if (x <= n - 1)
		{
			auto down = y == 0 ? 0 : psum2[x][y - 1];
			res = dp2(x, y, 2) + down;
		}

		// 그 외의 경우, 그냥 단순 내림차
		res = max(res, dp2(x, y, 0));
	}
	else
	{
		// 단순 내림차 -> 한칸 밑으로 or 오른쪽 한칸 밑으로
		if (y == 0)
		{
			// 맨밑 왔으면 그냥 한칸 옆으로 가면 됨
			res = dp2(x + 1, y, 1);
			return res;
		}

		if (x < n - 1 && y < 2)
			res = max(res, psum2[x][y] + dp2(x + 1, y + 1, 2));

		res = val2[x][y] + max({ dp2(x + 1, y - 1, 1), dp2(x, y - 1, 0), psum2[x][y - 1] + dp2(x + 2, y, 1), psum2[x][y - 1] + dp2(x + 3, y, 1) });
	}

	return res;
}

i64 sub3(int m, vector<int> x, vector<int> y, vector<int> w)
{
	for (int i = 0; i < m; i++)
		val2[x[i] + 1][y[i] + 1] += w[i];

	for (int x = 1; x <= n; x++)
		for (int y = 1; y <= 2; y++)
			psum2[x][y] += psum2[x][y - 1] + val2[x][y];

	memset(table2, -1, sizeof(table2));

	return dp2(0, 0, 1);
}

i64 max_weights(int n_, int m, vector<int> x, vector<int> y, vector<int> w)
{
	n = n_;
	if (all_of(all(x), [](int xi) { return xi % 2 == 0; }))
		return sub1(w);

	if (all_of(all(x), [](int xi) { return xi <= 1; }))
		return sub2(m, x, y, w);

	if (all_of(all(y), [](int yi) { return yi == 0; }))
		return sub3(m, x, y, w);

	for (int i = 0; i < m; i++)
		val[x[i] + 1][y[i] + 1] += w[i];

	for (int x = 1; x <= n; x++)
		for (int y = 1; y <= n; y++)
			psum[x][y] += psum[x][y - 1] + val[x][y];

	memset(table, -1, sizeof(table));

	return dp(0, 0, 1);
}
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...