#include "rect.h"
#include <bits/stdc++.h>
using namespace std;
int n, m;
vector<vector<int>> a;
namespace sub4{
const int N=700;
bitset<N> b[N][N], bs[N], c[N][N], msk[N][N];
bool check(){
return n<=700 && m<=700;
}
long long solve(){
for (int i=0; i<m; ++i) for (int j=i; j<m; ++j){
msk[i][j]=msk[i][j-1];
msk[i][j].set(j);
}
for (int i=0; i<n; ++i){
for (int l=0; l<m; ++l){
int mx=0;
for (int r=l+2; r<m; ++r){
mx=max(mx, a[i][r-1]);
if (a[i][l]>mx && a[i][r]>mx) b[i][l].set(r);
}
}
}
for (int i=0; i<m; ++i){
for (int l=0; l<n; ++l){
int mx=0;
for (int r=l+2; r<n; ++r){
mx=max(mx, a[r-1][i]);
if (a[l][i]>mx && a[r][i]>mx) c[i][l].set(r);
}
}
}
long long ans=0;
for (int i=0; i<n; ++i){
for (int j=0; j<m; ++j) bs[j].set();
for (int j=i+2; j<n; ++j){
for (int k=m-1, nxt=m-1; k>=0; --k){
bs[k]&=b[j-1][k];
if (k+2<=nxt) ans+=(bs[k]&msk[k+2][nxt]).count();
if (!c[k][i].test(j)) nxt=k;
}
}
}
return ans;
}
}
namespace sub5{
const int N=2500;
bitset<3> b[N][3], bs[3];
bitset<N> c[3][N];
bitset<3> msk[3][3];
bool check(){
return n<=3;
}
long long solve(){
vector<vector<int>> aa(m, vector<int>(n));
for (int i=0; i<n; ++i) for (int j=0; j<m; ++j) aa[j][i]=a[i][j];
swap(n, m);
aa.swap(a);
for (int i=0; i<m; ++i) for (int j=i; j<m; ++j){
msk[i][j]=msk[i][j-1];
msk[i][j].set(j);
}
for (int i=0; i<n; ++i){
for (int l=0; l<m; ++l){
int mx=0;
for (int r=l+2; r<m; ++r){
mx=max(mx, a[i][r-1]);
if (a[i][l]>mx && a[i][r]>mx) b[i][l].set(r);
}
}
}
for (int i=0; i<m; ++i){
for (int l=0; l<n; ++l){
int mx=0;
for (int r=l+2; r<n; ++r){
mx=max(mx, a[r-1][i]);
if (a[l][i]>mx && a[r][i]>mx) c[i][l].set(r);
}
}
}
long long ans=0;
for (int i=0; i<n; ++i){
for (int j=0; j<m; ++j) bs[j].set();
for (int j=i+2; j<n; ++j){
for (int k=m-1, nxt=m-1; k>=0; --k){
bs[k]&=b[j-1][k];
if (k+2<=nxt) ans+=(bs[k]&msk[k+2][nxt]).count();
if (!c[k][i].test(j)) nxt=k;
}
}
}
return ans;
}
}
namespace sub6{
bool check(){
for (int i=0; i<n; ++i) for (int j=0; j<m; ++j) if (a[i][j]>1) return 0;
return 1;
}
vector<vector<int>> h, pf;
int sum(int x, int y, int z, int t){
int ans=pf[z][t];
if (x) ans-=pf[x-1][t];
if (y) ans-=pf[z][y-1];
if (x && y) ans+=pf[x-1][y-1];
return ans;
}
long long solve(){
pf=a;
for (int i=0; i<n; ++i) for (int j=0; j<m; ++j){
pf[i][j]=a[i][j];
if (i) pf[i][j]+=pf[i-1][j];
if (j) pf[i][j]+=pf[i][j-1];
if (i && j) pf[i][j]-=pf[i-1][j-1];
}
h=a;
for (int i=0; i<n; ++i){
for (int j=0; j<m; ++j){
if (a[i][j]==1) h[i][j]=0;
else h[i][j]=i==0?1:h[i-1][j]+1;
}
}
long long ans=0;
for (int i=0; i<n-1; ++i){
for (int j=0; j<m-2; ++j) if (a[i][j]==1 && a[i][j+1]==0){
int k=j+1;
int maxh=h[i][k], minh=h[i][k];
while (k<m && a[i][k]==0){
maxh=max(maxh, h[i][k]);
minh=min(minh, h[i][k]);
++k;
}
if (k<m){
if (minh==maxh && maxh<=i
&& sum(i-maxh+1, j+1, i, k-1)==0
&& sum(i-maxh+1, j, i, j)==maxh && sum(i-maxh+1, k, i, k)==maxh
&& sum(i-maxh, j+1, i-maxh, k-1)==(k-j-1) && sum(i+1, j+1, i+1, k-1)==(k-j-1)){
++ans;
}
}
j=k-1;
}
}
return ans;
}
}
long long count_rectangles(vector<vector<int>> _a){
a=_a;
n=a.size(), m=a[0].size();
if (sub4::check()) return sub4::solve();
if (sub5::check()) return sub5::solve();
if (sub6::check()) return sub6::solve();
return 0;
}