#include <iostream>
#include <assert.h>
#include <bitset>
using namespace std;
const int INF = 1e9;
int S;
int t, n, k1, k2, c, dp[103][103][5001];
bitset<5001> par[103][103];
int ar[301][301], ans[301][301];
inline int getSum(int i1, int j1, int i2, int j2)
{
return ar[i2][j2] - ar[i1 - 1][j2] - ar[i2][j1 - 1] + ar[i1 - 1][j1 - 1];
}
int main()
{
cin >> t >> n >> k1 >> k2;
if (n <= 50) {S = 1;}
else {S = 3;}
for (int i = 0; i < n; i++)
{
for (int j = 0; j < n; j++)
{
int val; cin >> val;
ar[i / S + 1][j / S + 1] += val;
}
}
int ns = n / S;
for (int i = 1; i <= ns + 1; i++)
{
for (int j = 1; j <= ns + 1; j++)
{
ar[i][j] += ar[i - 1][j];
ar[i][j] += ar[i][j - 1];
ar[i][j] -= ar[i - 1][j - 1];
}
}
for (int i = 1; i <= ns; i++)
{
for (int j = 1; j <= ns + 1; j++)
{
for (int k = 0; 2 * k <= ns * ns; k++)
{
dp[i][j][k] = -INF;
}
}
}
dp[1][ns + 1][0] = 0;
for (int i = 1; i <= ns; i++)
{
for (int j = ns + 1; j > 0; j--)
{
for (int k = (ns + 1 - i) * (ns + 1 - j); 2 * k <= ns * ns; k++)
{
if (dp[i][j][k] == -INF)
{
if (j <= ns && k >= ns - i + 1)
{
if (dp[i][j + 1][k - (ns - i + 1)] > -INF && dp[i][j + 1][k - (ns - i + 1)] + getSum(i, j, ns + 1, j) > dp[i][j][k])
{
dp[i][j][k] = dp[i][j + 1][k - (ns - i + 1)] + getSum(i, j, ns + 1, j);
par[i][j][k] = 0;
}
}
if (i > 1)
{
if (dp[i - 1][j][k] > -INF && dp[i - 1][j][k] > dp[i][j][k])
{
dp[i][j][k] = dp[i - 1][j][k];
par[i][j][k] = 1;
}
}
//cout << i << " " << j << " " << k << " " << dp[i][j][k] << "\n";
}
}
}
}
int pi = ns, pj = 1, pk = ns * ns / 2, cnt = 0;
//cout << "Best : " << dp[pi][pj][pk] << "\n";
for (int j = 2; j <= ns + 1; j++)
{
if (dp[pi][j][pk] > dp[pi][pj][pk]) pj = j;
}
while (pj <= ns)
{
//cout << pi << " " << pj << " " << pk << " " << dp[pi][pj][pk] << " " << par[pi][pj][pk] << "\n";
if (par[pi][pj][pk]) {pi--;}
else
{
for (int i = (pi - 1) * S; i < n; i++)
{
for (int j = (pj - 1) * S; j < pj * S; j++)
{
ans[i][j] = 1;
cnt += ans[i][j];
}
}
pj++; pk -= (ns - pi + 1);
}
}
assert(cnt == n * n / 2);
for (int i = 0; i < n; i++)
{
for (int j = 0; j < n; j++)
{
assert(0 <= ans[i][j] && ans[i][j] <= 1);
cout << ans[i][j];
if (j + 1 < n) cout << " ";
}
cout << "\n";
}
return 0;
}
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Correct |
1 ms |
896 KB |
K = 17 |
2 |
Correct |
23 ms |
24832 KB |
K = 576 |
3 |
Correct |
321 ms |
204408 KB |
K = 18637 |
4 |
Correct |
296 ms |
204664 KB |
K = 21814 |
5 |
Correct |
295 ms |
204536 KB |
K = 17127 |
6 |
Correct |
318 ms |
204536 KB |
K = 20861 |
7 |
Correct |
333 ms |
204536 KB |
K = 21447 |
8 |
Correct |
297 ms |
204540 KB |
K = 19612 |
9 |
Correct |
295 ms |
204528 KB |
K = 20771 |
10 |
Correct |
304 ms |
204376 KB |
K = 20150 |