#include "wombats.h"
#include<bits/stdc++.h>
using namespace std;
// Grid dimensions (rows, columns), set by init().
int R;
int C;
// Edge weights copied from the grader's H / V arrays in init():
//   d[i][j] = V[i][j]    step from row i down to row i+1 at column j
//   l[i][j] = H[i][j-1]  edge between columns j-1 and j of row i
//   r[i][j] = H[i][j]    edge between columns j and j+1 of row i
vector<vector<int>> d;
vector<vector<int>> l;
vector<vector<int>> r;
// Segment tree over rows: st[node] is a C x C matrix of path costs for that node's row range.
vector<vector<vector<int>>> st;
// Scratch state: opt records the chosen column per matrix cell; dum1 / dum2 are temporaries.
vector<vector<int>> opt;
vector<vector<int>> dum1;
vector<vector<int>> dum2;
// Builds the one-step cost matrix between row tl and row tr (callers pass tr == tl+1):
//   res[a][b] = min over k of  (horizontal cost between columns a and k on row tl)
//                             + d[tl][k]                      (step down at column k)
//                             + (horizontal cost between columns k and b on row tr)
// The chosen k for each cell is recorded in the global opt table. Diagonal cells scan every
// k. Each other cell only scans k between the opt values of its two neighbours one diagonal
// closer; that is exact only if the argmin is monotone (assumed here, not checked).
// 1<<25 serves as "infinity".
void calc(vector<vector<int>>& res, int tl, int tr){
    // top[c] / bot[c]: cumulative horizontal weight from column 0 to column c on rows tl / tr.
    vector<int> top(C, 0), bot(C, 0);
    for (int c = 1; c < C; c++){
        top[c] = top[c-1] + r[tl][c-1];
        bot[c] = bot[c-1] + r[tr][c-1];
    }
    // Set res[a][b] to the cheapest crossing column k in [lo, hi] and remember k in opt[a][b].
    auto relax = [&](int a, int b, int lo, int hi){
        int& best = res[a][b];
        best = 1<<25;
        for (int k = lo; k <= hi; k++){
            const int cost = abs(top[a]-top[k]) + abs(bot[b]-bot[k]) + d[tl][k];
            if (cost < best){
                opt[a][b] = k;
                best = cost;
            }
        }
    };
    for (int a = 0; a < C; a++) relax(a, a, 0, C-1);
    for (int dif = 1; dif < C; dif++){
        for (int a = 0; a + dif < C; a++){
            const int b = a + dif;
            relax(a, b, opt[a][b-1], opt[a+1][b]);   // cell above the diagonal
            relax(b, a, opt[b-1][a], opt[b][a+1]);   // mirrored cell below the diagonal
        }
    }
}
// Min-plus product of two stacked segment matrices:
//   res[a][b] = min over k of l[a][k] + r[k][b]
// where k is the column on the row the two segments share. The chosen k per cell goes into
// the global opt table, using the same neighbour-bounded search as calc(); the argmin is
// assumed monotone (not checked). res is written while l and r are read, so it must not
// alias either of them. 1<<25 serves as "infinity".
void merge(vector<vector<int>>& res, vector<vector<int>>& l, vector<vector<int>>& r){
    // Set res[a][b] to the cheapest join column k in [lo, hi] and remember k in opt[a][b].
    auto relax = [&](int a, int b, int lo, int hi){
        int& best = res[a][b];
        best = 1<<25;
        for (int k = lo; k <= hi; k++){
            const int through = l[a][k] + r[k][b];
            if (through < best){
                opt[a][b] = k;
                best = through;
            }
        }
    };
    for (int a = 0; a < C; a++) relax(a, a, 0, C-1);
    for (int dif = 1; dif < C; dif++){
        for (int a = 0; a + dif < C; a++){
            const int b = a + dif;
            relax(a, b, opt[a][b-1], opt[a+1][b]);   // cell above the diagonal
            relax(b, a, opt[b-1][a], opt[b][a+1]);   // mirrored cell below the diagonal
        }
    }
}
void compute(int node, int tl, int tr){
calc(st[node], tl, tl+1);
for (int i=tl+1; i<tr; i++){
dum1 = st[node];
calc(dum2, i, i+1);
merge(st[node], dum1, dum2);
}
}
// Recomputes every segment-tree node whose row range [tl, tr] contains row pos: the leaf
// is rebuilt from scratch, then each ancestor is re-merged on the way back up. Children
// share the split row mid, so when pos == mid both children are refreshed.
void update(int pos, int node=1, int tl=0, int tr=R-1){
    if (tr - tl > 10){
        const int mid = (tl + tr) / 2;
        const bool inLeft = pos <= mid;
        const bool inRight = pos >= mid;
        if (inLeft) update(pos, 2*node, tl, mid);
        if (inRight) update(pos, 2*node + 1, mid, tr);
        merge(st[node], st[2*node], st[2*node + 1]);
    } else {
        compute(node, tl, tr);
    }
}
// Builds the whole segment tree over rows [tl, tr]. Ranges of at most 11 rows become leaves
// filled by compute(). Larger ranges split at mid (both halves keep row mid) and their
// children's matrices are merged.
void build(int node=1, int tl=0, int tr=R-1){
    if (tr - tl > 10){
        const int mid = (tl + tr) / 2;
        build(2*node, tl, mid);
        build(2*node + 1, mid, tr);
        merge(st[node], st[2*node], st[2*node + 1]);
    } else {
        compute(node, tl, tr);
    }
}
// Largest segment-tree node index that build() / update() visit for rows [tl, tr].
// The leaf test (tr - tl <= 10) and the split point must stay in sync with build() / update().
static int maxNodeIndex(int node, int tl, int tr){
    if (tr - tl <= 10) return node;
    int tm = (tl+tr)/2;
    return max(maxNodeIndex(node*2, tl, tm), maxNodeIndex(node*2+1, tm, tr));
}
// Grader entry point. Copies the grid weights into d / l / r, allocates the segment tree
// and the scratch matrices, then builds the tree.
//   R, C : number of rows and columns.
//   H    : H[i][j] is stored as r[i][j] and l[i][j+1] (edge between columns j and j+1 of row i).
//   V    : V[i][j] is stored as d[i][j] for i < R-1 (step from row i down to row i+1 at column j).
void init(int R, int C, int H[5000][200], int V[5000][200]) {
    ::R = R; ::C = C;
    d.assign(R, vector<int>(C, 0));
    l.assign(R, vector<int>(C, 0));
    r.assign(R, vector<int>(C, 0));
    for (int i=0; i<R; i++){
        for (int j=0; j<C; j++){
            if (i < R-1) d[i][j] = V[i][j];
            if (j > 0) l[i][j] = H[i][j-1];
            if (j < C-1) r[i][j] = H[i][j];
        }
    }
    // Allocate only the nodes the recursion actually reaches (about 1024 for R = 5000).
    // Every node is a dense C x C int matrix, so the previous fixed (1<<14)+5 nodes cost
    // roughly 2.6 GB at C = 200.
    st.assign(maxNodeIndex(1, 0, R-1) + 1, vector<vector<int>>(C, vector<int>(C)));
    opt.assign(C, vector<int>(C));
    dum1.assign(C, vector<int>(C));
    dum2.assign(C, vector<int>(C));
    build();
}
// Grader entry point: sets the weight of the edge between columns Q and Q+1 of row P to W,
// then refreshes every tree node whose row range contains row P.
void changeH(int P, int Q, int W) {
    // The same edge is stored from both endpoints: as r on its left cell, as l on its right cell.
    vector<int>& rightOf = r[P];
    vector<int>& leftOf = l[P];
    rightOf[Q] = W;
    leftOf[Q + 1] = W;
    update(P);
}
// Grader entry point: sets the weight of the step from row P down to row P+1 at column Q
// to W, then refreshes every tree node whose row range contains row P.
void changeV(int P, int Q, int W) {
    vector<int>& downFrom = d[P];
    downFrom[Q] = W;
    update(P);
}
// Grader entry point: cheapest cost from column V1 on row 0 to column V2 on row R-1,
// read from the root matrix (node 1 covers rows 0..R-1).
int escape(int V1, int V2) {
    const vector<vector<int>>& root = st[1];
    return root[V1][V2];
}
/* Judge-results tables pasted alongside this submission; none had loaded ("Fetching results...").
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
*/