#include "wombats.h"
#include<bits/stdc++.h>
using namespace std;
// Grid dimensions (rows, columns), copied from init()'s arguments.
int R, C;
// d[i][j]: weight of the vertical edge from (i, j) down to (i+1, j); row R-1 is left 0.
// l[i][j] / r[i][j]: weight of the horizontal edge to the left / right of (i, j),
// 0 where no such edge exists (column 0 for l, column C-1 for r).
// NOTE(review): l is written (init, changeH) but never read in the visible code.
vector<vector<int>> d, l, r;
// st[node][a][b]: minimum cost from column a of the node's first row (tl) to
// column b of its last row (tr). Node 1 is the root; children are 2*node and 2*node+1.
vector<vector<vector<int>>> st;
// opt[a][b]: scratch table holding the leftmost argmin chosen for entry [a][b] by the
// most recent calc/merge call; it bounds the search range of later entries
// (Knuth-style optimization). Shared by calc and merge, so they must not interleave.
vector<vector<int>> opt;
// Fills st[node] for a leaf covering the transition from row tl to row tr.
// Entry [a][b] is the cheapest route that walks along row tl from column a to
// some column `via`, drops down through d[tl][via], then walks along row tr
// to column b. The leftmost optimal `via` is recorded in opt[a][b]; entries are
// filled by increasing |a - b| so each search is bounded by neighbouring optima.
void calc(int node, int tl, int tr){
    // top[c] / bottom[c]: cost of walking from column 0 to column c on row tl / row tr.
    vector<int> top(C, 0), bottom(C, 0);
    for (int c = 1; c < C; c++){
        top[c] = top[c-1] + r[tl][c-1];
        bottom[c] = bottom[c-1] + r[tr][c-1];
    }

    vector<vector<int>> &cost = st[node];
    const vector<int> &down = d[tl];

    // Price of the route a -> via (row tl), down, via -> b (row tr).
    auto route = [&](int a, int b, int via){
        return abs(top[a] - top[via]) + abs(bottom[b] - bottom[via]) + down[via];
    };

    // Minimises cost[a][b] over via in [lo, hi], keeping the first strict minimum.
    auto relax = [&](int a, int b, int lo, int hi){
        cost[a][b] = 1<<25;
        for (int via = lo; via <= hi; via++){
            int cand = route(a, b, via);
            if (cand < cost[a][b]){
                cost[a][b] = cand;
                opt[a][b] = via;
            }
        }
    };

    // Diagonal: unrestricted search.
    for (int a = 0; a < C; a++) relax(a, a, 0, C - 1);

    // Off-diagonals, both triangles, in order of increasing distance from the diagonal.
    for (int dif = 1; dif < C; dif++){
        for (int a = 0; a + dif < C; a++){
            int b = a + dif;
            relax(a, b, opt[a][b-1], opt[a+1][b]);
            relax(b, a, opt[b-1][a], opt[b][a+1]);
        }
    }
}
// Combines the two children of `node`: st[node][a][b] is the minimum over the
// shared middle row's column `mid` of child(2*node)[a][mid] + child(2*node+1)[mid][b].
// The leftmost optimal `mid` is stored in opt[a][b]; entries are filled by
// increasing |a - b| so each search is bounded by neighbouring optima.
void merge(int node){
    const vector<vector<int>> &topHalf = st[2*node];
    const vector<vector<int>> &bottomHalf = st[2*node + 1];
    vector<vector<int>> &res = st[node];

    // Minimises res[a][b] over mid in [lo, hi], keeping the first strict minimum.
    auto relax = [&](int a, int b, int lo, int hi){
        res[a][b] = 1<<25;
        for (int mid = lo; mid <= hi; mid++){
            int cand = topHalf[a][mid] + bottomHalf[mid][b];
            if (cand < res[a][b]){
                res[a][b] = cand;
                opt[a][b] = mid;
            }
        }
    };

    // Diagonal: unrestricted search.
    for (int a = 0; a < C; a++) relax(a, a, 0, C - 1);

    // Off-diagonals, both triangles, in order of increasing distance from the diagonal.
    for (int dif = 1; dif < C; dif++){
        for (int a = 0; a + dif < C; a++){
            int b = a + dif;
            relax(a, b, opt[a][b-1], opt[a+1][b]);
            relax(b, a, opt[b-1][a], opt[b][a+1]);
        }
    }
}
// Recomputes every segment-tree node whose row range [tl, tr] contains row `pos`,
// re-merging on the way back up. Children share the middle row tm, so a change
// on row tm refreshes both subtrees.
// pos: row whose edge weights changed; node/tl/tr: current node and its row range.
void update(int pos, int node=1, int tl=0, int tr=R-1){
    // A node spanning at most one row transition is a leaf. tl == tr only happens
    // for a single-row grid (R == 1); splitting [0, 0] would recurse forever.
    if (tr - tl <= 1){
        calc(node, tl, tr);
        return;
    }
    int tm = (tl+tr)/2;
    if (pos <= tm) update(pos, node*2, tl, tm);
    if (tm <= pos) update(pos, node*2+1, tm, tr);
    merge(node);
}
// Builds the segment tree over rows [tl, tr]: leaves (one row transition) are
// filled by calc, inner nodes by merging their two children, which share row tm.
void build(int node=1, int tl=0, int tr=R-1){
    // A node spanning at most one row transition is a leaf. tl == tr only happens
    // for a single-row grid (R == 1); splitting [0, 0] would recurse forever.
    if (tr - tl <= 1){
        calc(node, tl, tr);
        return;
    }
    int tm = (tl+tr)/2;
    build(node*2, tl, tm);
    build(node*2+1, tm, tr);
    merge(node);
}
// Stores the grid and builds the segment tree.
// H[i][j]: weight of the horizontal edge between (i, j) and (i, j+1), j < C-1.
// V[i][j]: weight of the vertical edge between (i, j) and (i+1, j), i < R-1.
void init(int R, int C, int H[5000][200], int V[5000][200]) {
    ::R = R; ::C = C;
    // assign (not resize) so every cell, including the never-written borders, is 0.
    d.assign(R, vector<int>(C, 0));
    l.assign(R, vector<int>(C, 0));
    r.assign(R, vector<int>(C, 0));
    for (int i=0; i<R; i++){
        for (int j=0; j<C; j++){
            if (i < R-1) d[i][j] = V[i][j];
            if (j > 0) l[i][j] = H[i][j-1];
            if (j < C-1) r[i][j] = H[i][j];
        }
    }
    // The tree has max(1, R-1) leaves and halves ranges as evenly as possible,
    // so node indices stay below 2 * (smallest power of two >= leaf count).
    // (The former `1<<14+5` parsed as 1<<19 because + binds tighter than <<,
    // allocating 524288 C*C tables regardless of R.)
    // NOTE(review): memory is still nodes * C * C ints, about 2.6 GB at R=5000, C=200.
    int leaves = R > 1 ? R - 1 : 1;
    int width = 1;
    while (width < leaves) width <<= 1;
    st.assign(2 * width, vector<vector<int>>(C, vector<int>(C, 0)));
    opt.assign(C, vector<int>(C, 0));
    build();
}
// Sets the weight of the horizontal edge between (P, Q) and (P, Q+1) to W.
// The edge is mirrored in l and r; every tree node containing row P is refreshed.
void changeH(int P, int Q, int W) {
    l[P][Q + 1] = W;
    r[P][Q] = W;
    update(P);
}
// Sets the weight of the vertical edge from (P, Q) down to (P+1, Q) to W,
// then refreshes every tree node containing row P.
void changeV(int P, int Q, int W) {
    vector<int> &row = d[P];
    row[Q] = W;
    update(P);
}
// Minimum cost from column V1 of the top row to column V2 of the bottom row,
// read from the root node (index 1), whose range spans all rows.
int escape(int V1, int V2) {
    const vector<vector<int>> &whole = st[1];
    return whole[V1][V2];
}
/* Judge-page residue (scraped result table, not part of the program):
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
*/