#include "rect.h"
#include <bits/stdc++.h>
using namespace std;
const int N=700;
bitset<N> b[N][N], bs[N], c[N][N], msk[N][N];
long long count_rectangles(vector<vector<int>> a){
int n=a.size(), m=a[0].size();
for (int i=0; i<n; ++i) for (int j=i; j<n; ++j){
msk[i][j]=msk[i][j-1];
msk[i][j].set(j);
}
for (int i=0; i<n; ++i){
for (int l=0; l<m; ++l){
int mx=0;
for (int r=l+2; r<m; ++r){
mx=max(mx, a[i][r-1]);
if (a[i][l]>mx && a[i][r]>mx) b[i][l].set(r);
}
}
}
for (int i=0; i<m; ++i){
for (int l=0; l<n; ++l){
int mx=0;
for (int r=l+2; r<n; ++r){
mx=max(mx, a[r-1][i]);
if (a[l][i]>mx && a[r][i]>mx) c[i][l].set(r);
}
}
}
long long ans=0;
for (int i=0; i<n; ++i){
for (int j=0; j<m; ++j) bs[j].set();
for (int j=i+2; j<n; ++j){
for (int k=m-1, nxt=m-1; k>=0; --k){
bs[k]&=b[j-1][k];
if (k+2<=nxt) ans+=(bs[k]&msk[k+2][nxt]).count();
if (!c[k][i].test(j)) nxt=k;
}
}
}
return ans;
}