#include <bits/stdc++.h>
#include "rect.h"
using namespace std;
/// @brief Counts zero-sum rectangles whose four adjacent border strips are "full".
///
/// Every non-border cell (i, j) whose value is not 1 is tried as a top-left
/// corner. From it the rectangle is grown right along row i for as long as the
/// row segment sums to 0, then grown downward for as long as the whole
/// rectangle sums to 0 (both growths stop one cell short of the grid edge).
/// The rectangle is counted when the sum of each strip directly above, below,
/// left of and right of it equals that strip's length.
///
/// NOTE(review): "sum == length" only means "all ones", and "sum == 0" only
/// means "all zeros", if every cell is 0 or 1. Nothing here enforces that --
/// confirm the caller guarantees it. Rows are also assumed to all have the
/// same length as a[0].
///
/// @param a Grid of cell values, indexed a[row][column].
/// @return Number of rectangles accepted by the check above; 0 for an empty
///         grid or one with fewer than 3 rows or 3 columns.
long long count_rectangles(std::vector<std::vector<int>> a) {
    // Without any row there is no a[0] to take the width from.
    if (a.empty()) {
        return 0;
    }
    const int n = static_cast<int>(a.size());
    const int m = static_cast<int>(a[0].size());

    // pre[r][c] holds the sum of a over rows [0, r) and columns [0, c).
    std::vector<std::vector<int>> pre(n + 1, std::vector<int>(m + 1));
    for (int r = 0; r < n; ++r) {
        for (int c = 0; c < m; ++c) {
            pre[r + 1][c + 1] = pre[r][c + 1] + pre[r + 1][c] - pre[r][c] + a[r][c];
        }
    }

    // Sum over the inclusive range rows x1..x2, columns y1..y2 (0-based).
    const auto get = [&pre](int x1, int y1, int x2, int y2) {
        return pre[x2 + 1][y2 + 1] - pre[x2 + 1][y1] - pre[x1][y2 + 1] + pre[x1][y1];
    };

    long long ans = 0;
    // Corners are restricted to non-border cells so that the strips at
    // i - 1, j - 1, bottom + 1 and right + 1 always lie inside the grid.
    for (int i = 1; i < n - 1; ++i) {
        for (int j = 1; j < m - 1; ++j) {
            if (a[i][j] == 1) {
                continue;
            }

            // Binary search for the last column `right` in [j, m - 2] such that
            // row i, columns j..right, sums to 0. `lo` always satisfies that
            // condition; `hi` is either a failing column or the m - 1 sentinel.
            int lo = j;
            int hi = m - 1;
            while (lo + 1 < hi) {
                const int mid = (lo + hi) / 2;
                if (get(i, j, i, mid) == 0) {
                    lo = mid;
                } else {
                    hi = mid;
                }
            }
            const int right = lo;

            // Same search downward: last row `bottom` in [i, n - 2] such that
            // the rectangle (i, j)..(bottom, right) sums to 0.
            lo = i;
            hi = n - 1;
            while (lo + 1 < hi) {
                const int mid = (lo + hi) / 2;
                if (get(i, j, mid, right) == 0) {
                    lo = mid;
                } else {
                    hi = mid;
                }
            }
            const int bottom = lo;

            const int width = right - j + 1;
            const int height = bottom - i + 1;

            // Sums of the four strips hugging the rectangle.
            const int above = get(i - 1, j, i - 1, right);
            const int below = get(bottom + 1, j, bottom + 1, right);
            const int left_of = get(i, j - 1, bottom, j - 1);
            const int right_of = get(i, right + 1, bottom, right + 1);

            if (above == width && below == width && left_of == height && right_of == height) {
                ++ans;
            }
        }
    }
    return ans;
}