#include "rect.h"
#include <bits/stdc++.h>
// Inclusive ascending loop: i = a, a+1, ..., b (i is declared long long).
#define rep(i,a,b) for(ll i=(a);i<=(b);++i)
// Inclusive descending loop: i = b, b-1, ..., a (i is declared long long).
#define per(i,a,b) for(ll i=(b);i>=(a);--i)
// Container size as a signed long long.
#define siz(x) ((ll)(x).size())
// begin/end iterator pair over a whole container, for <algorithm> calls.
#define all(x) (x).begin(), (x).end()
using namespace std; using ll = long long; using vi = vector<ll>; using ii = pair<ll,ll>;
// Capacity of every global grid below. The grids are indexed 1..n / 1..m
// (plus border index 0 and n+1 / m+1 in places), so presumably n, m <= N - 3
// is required -- TODO confirm against the problem limits.
const int N = 2503;
// 1-indexed copy of the input grid (the border row/column may hold sentinel values).
int a[N][N];
// Grid dimensions: n rows, m columns.
int n, m;
// row[l][r]: sorted row indices i of interior cells whose row span [L, R]
// equals [l, r]. col[u][d]: sorted column indices j of interior cells whose
// column span [U, D] equals [u, d]. short is enough for indices below N.
vector<short> row[N][N], col[N][N];
// For cell (i, j): [L, R] is the maximal column range of row i around j, and
// [U, D] the maximal row range of column j around i, that contains no value
// strictly greater than a[i][j].
int L[N][N], R[N][N], U[N][N], D[N][N];
// Returns true iff the sorted list v contains every integer in [lo, hi].
// Assumes v has no repeated values: rather than scanning, it counts the
// entries in [lo, hi), so a repeat inside that range makes it return false.
// Both iterators are bounds-checked before being dereferenced.
static bool covers(const std::vector<short> &v, int lo, int hi) {
    auto first = std::lower_bound(v.begin(), v.end(), lo);
    auto last = std::lower_bound(v.begin(), v.end(), hi);
    // first <= last, so this also guards *first. It covers the case where
    // hi (and everything above it) is missing, which the original read past
    // the end for.
    if (last == v.end()) return false;
    return *first == lo && *last == hi && last - first == hi - lo;
}

// Checks whether rows u..d x columns l..r (1-indexed, inclusive) form a valid
// rectangle:
// - it lies strictly inside the grid;
// - every row u..d has [l, r] recorded as a span (row[l][r]);
// - every column l..r has [u, d] recorded as a span (col[u][d]).
bool valid(int u, int d, int l, int r) {
    if (u <= 1 || d >= n || l <= 1 || r >= m) return false;
    return covers(row[l][r], u, d) && covers(col[u][d], l, r);
}
// Counts the rectangles made of rows u..d and columns l..r (1-indexed, with
// 2 <= u <= d <= n-1 and 2 <= l <= r <= m-1) in which every interior cell is
// strictly smaller than its row's border cells a[i][l-1], a[i][r+1] and its
// column's border cells a[u-1][j], a[d+1][j].
//
// Method: every such rectangle has a maximal interior cell (i, j). The
// maximal spans around that cell bounded by strictly larger values ([L, R] in
// its row, [U, D] in its column) are exactly the rectangle's extents. Each
// interior cell therefore proposes one candidate, valid() checks it, and
// candidates proposed by several equal maxima are deduplicated at the end.
//
// Works on the global arrays sized N. Assumes a rectangular grid with
// n, m < N. The global lists it reads are reset first, so repeated calls
// are independent.
ll count_rectangles(vector<vector<int> > _a) {
    if (_a.empty() || _a[0].empty()) return 0;
    n = siz(_a); m = siz(_a[0]);
    rep(i,1,n) rep(j,1,m) a[i][j]=_a[i-1][j-1];

    // Discard span lists left over from a previous call. Only l <= r and
    // u <= d are ever queried.
    rep(l,1,m) rep(r,l,m) row[l][r].clear();
    rep(u,1,n) rep(d,u,n) col[u][d].clear();

    // Monotonic stack: for each cell, find the nearest strictly larger cell on
    // each side. An empty stack means the span reaches the grid edge. No
    // sentinel value is used, so any int value in the grid is handled.
    vector<int> stk;
    rep(i,1,n) {
        stk.clear();
        rep(j,1,m) {
            while (!stk.empty() && a[i][stk.back()] <= a[i][j]) stk.pop_back();
            L[i][j] = stk.empty() ? 1 : stk.back() + 1;
            stk.push_back(j);
        }
        stk.clear();
        per(j,1,m) {
            while (!stk.empty() && a[i][stk.back()] <= a[i][j]) stk.pop_back();
            R[i][j] = stk.empty() ? m : stk.back() - 1;
            stk.push_back(j);
        }
    }
    rep(j,1,m) {
        stk.clear();
        rep(i,1,n) {
            while (!stk.empty() && a[stk.back()][j] <= a[i][j]) stk.pop_back();
            U[i][j] = stk.empty() ? 1 : stk.back() + 1;
            stk.push_back(i);
        }
        stk.clear();
        per(i,1,n) {
            while (!stk.empty() && a[stk.back()][j] <= a[i][j]) stk.pop_back();
            D[i][j] = stk.empty() ? n : stk.back() - 1;
            stk.push_back(i);
        }
    }

    // row[l][r] collects the rows whose segment [l, r] is good; col[u][d] does
    // the same for columns. Several equal maxima in one segment yield the same
    // span, which used to push the same index repeatedly. That broke the
    // counting check in valid() and rejected real rectangles. The outer loop
    // index is non-decreasing per list, so comparing with back() is enough.
    rep(i,2,n-1) rep(j,2,m-1) {
        auto &v = row[L[i][j]][R[i][j]];
        if (v.empty() || v.back() != i) v.push_back(i);
    }
    rep(j,2,m-1) rep(i,2,n-1) {
        auto &v = col[U[i][j]][D[i][j]];
        if (v.empty() || v.back() != j) v.push_back(j);
    }

    vector<ll> rects;
    rep(i,2,n-1) rep(j,2,m-1) {
        if (valid(U[i][j], D[i][j], L[i][j], R[i][j])) {
            // Injective mixed-radix key: U, D <= n and L <= m.
            rects.push_back( U[i][j] + D[i][j] * (n+1ll) + L[i][j] * (n+1ll) * (n+1ll) + R[i][j] * (n+1ll) * (n+1ll) * (m+1ll) );
        }
    }
    sort(all(rects)); rects.erase(unique(all(rects)), end(rects));
    return siz(rects);
}