#include "rect.h"
#include <bits/stdc++.h>
using namespace std;
const int LG=12, N=2500;
int n, m;
vector<vector<int>> a;
pair<int, int> stmin[LG][N], stmax[LG][N];
vector<pair<pair<int, int>, int>> ev[N][N];
void gen_r(int mode){
map<pair<int, int>, int> mp, mp2;
for (int i=0; i<n; ++i){
for (int j=0; j<m; ++j) stmin[0][j]={a[i][j], -j}, stmax[0][j]={a[i][j], j};
for (int k=1; k<LG; ++k){
for (int j=0; j+(1<<k)-1<m; ++j){
stmin[k][j]=max(stmin[k-1][j], stmin[k-1][j+(1<<(k-1))]);
stmax[k][j]=max(stmax[k-1][j], stmax[k-1][j+(1<<(k-1))]);
}
}
vector<pair<int, int>> vv;
auto get=[&](pair<int, int> st[LG][N], int l, int r) -> int {
int lg=__lg(r-l+1);
return abs(max(st[lg][l], st[lg][r-(1<<lg)+1]).second);
};
auto dnc=[&](auto &&self, int l, int r) -> void {
if (l+1>=r) return;
int id=get(stmin, l, r);
int _r=r;
while (id+1<=r){
int id2=get(stmin, id+1, r);
if (id+1<id2) vv.emplace_back(id, id2);
r=id2-1;
}
r=_r;
int _l=l;
while (l<=id-1){
int id2=get(stmax, l, id-1);
if (id2+1<id) vv.emplace_back(id2, id);
l=id2+1;
}
l=_l;
self(self, l, id-1);
self(self, id+1, r);
};
dnc(dnc, 0, m-1);
for (auto &j:vv){
auto it=mp.find(j);
int L=j.first, R=j.second, U=i, D=i;
if (it!=mp.end()) U=it->second;
mp2[{L, R}]=U;
--U; ++D;
if (mode==0){
// cout << mode << ' ' << L << ' ' << R << ' ' << U << ' ' << D << endl;
if (D<n) ev[R][D].push_back({{U, 0}, L});
}else{
// cout << mode << ' ' << U << ' ' << D << ' ' << L << ' ' << R << endl;
if (D<n) ev[D][R].push_back({{L, 1}, U});
}
}
mp.swap(mp2);
mp2.clear();
}
}
struct BIT{
int n;
int t[N];
void init(int _n){
n=_n;
memset(t, 0, sizeof t);
}
void update(int pos, int val){
for (int i=pos; i<=n; i+=i&(-i)) t[i]+=val;
}
int get(int pos){
int ans=0;
for (int i=pos; i; i-=i&(-i)) ans+=t[i];
return ans;
}
} bit;
long long count_rectangles(vector<vector<int>> _a){
a=_a;
n=a.size(), m=a[0].size();
gen_r(0);
vector<vector<int>> aa(m, vector<int>(n));
for (int i=0; i<n; ++i) for (int j=0; j<m; ++j) aa[j][i]=a[i][j];
swap(n, m);
a.swap(aa);
gen_r(1);
long long ans=0;
swap(n, m);
for (int i=0; i<m; ++i) for (int j=0; j<n; ++j){
sort(ev[i][j].begin(), ev[i][j].end());
if (ev[i][j].size()){
bit.init(m);
for (auto &x:ev[i][j]){
if (x.first.second==0){
bit.update(m-x.second, 1);
}else{
ans+=bit.get(m-x.second);
}
}
// cout << ans << endl;
}
}
return ans;
}