#include "rect.h"
#include <bits/stdc++.h>
using namespace std;
// Axis-aligned rectangle described by two rows (l1, l2) and two columns (c1, c2).
// Not referenced elsewhere in this file.
struct rect {
    int l1, c1, l2, c2;

    // Strict lexicographic order on the tuple (l1, c1, l2, c2).
    bool operator < (const rect &other) const {
        if (l1 != other.l1)
            return l1 < other.l1;
        if (c1 != other.c1)
            return c1 < other.c1;
        if (l2 != other.l2)
            return l2 < other.l2;
        return c2 < other.c2;
    }

    // True when all four fields match (equivalently: neither side orders before the other).
    bool operator == (const rect &other) const {
        return l1 == other.l1 && c1 == other.c1 && l2 == other.l2 && c2 == other.c2;
    }
};
// Largest grid dimension the static arrays are sized for. Grid cells use
// 1-based indices 1..n / 1..m plus border slots 0 and n+1 / m+1, hence "+ 2".
const int MAX_N = 2500;
// Highest sparse-table level index; levels 0..MAX_LOG_N exist (2^12 > MAX_N).
const int MAX_LOG_N = 12;
// Sentinel that init() writes into the border cells of `mat` so the monotonic
// stacks always stop there. Assumes every grid value is below INF — TODO
// confirm against the problem's value bounds.
const int INF = 1e9;
// Number of grid rows (n) and columns (m), set by count_rectangles().
int n, m;
// Running rectangle count, incremented by solve().
long long ans = 0;
// 1-based copy of the input grid; init() fills the INF border.
int mat[MAX_N + 2][MAX_N + 2];
// Nearest-greater indices per cell, filled by init() with monotonic stacks:
//   xxxGE[l][c]: nearest cell in that direction whose value is >= mat[l][c];
//   xxxG[l][c]:  nearest cell in that direction whose value is >  mat[l][c].
// left/right hold a column index (0 / m+1 when there is none), up/down a row
// index (0 / n+1 when there is none). The strict (G) arrays are written by
// init() but not read anywhere else in this file.
int leftG[MAX_N + 2][MAX_N + 2], leftGE[MAX_N + 2][MAX_N + 2];
int rightG[MAX_N + 2][MAX_N + 2], rightGE[MAX_N + 2][MAX_N + 2];
int upG[MAX_N + 2][MAX_N + 2], upGE[MAX_N + 2][MAX_N + 2];
int downG[MAX_N + 2][MAX_N + 2], downGE[MAX_N + 2][MAX_N + 2];
// Sparse tables; entry [i][j][p] covers positions j .. j + 2^p - 1 along the
// last spatial index. Each table is (MAX_N+2)^2 * (MAX_LOG_N+1) ints (~81M).
//   minRight[c][l][p]: level 0 is rightGE[l][c] (over rows), queried as range min;
//   maxLeft[c][l][p]:  level 0 is leftGE[l][c]  (over rows), queried as range max;
//   minDown[l][c][p]:  level 0 is downGE[l][c]  (over columns), queried as range min;
//   maxUp[l][c][p]:    level 0 is upGE[l][c]    (over columns), queried as range max.
int minRight[MAX_N + 2][MAX_N + 2][MAX_LOG_N + 1];
int maxLeft[MAX_N + 2][MAX_N + 2][MAX_LOG_N + 1];
int minDown[MAX_N + 2][MAX_N + 2][MAX_LOG_N + 1];
int maxUp[MAX_N + 2][MAX_N + 2][MAX_LOG_N + 1];
// Candidate boundaries for a rectangle whose top-left boundary corner is
// (l, c): l2s[l][c] holds candidate bottom boundary rows and c2s[l][c]
// candidate right boundary columns (sorted and deduplicated before use).
vector<int> l2s[MAX_N + 2][MAX_N + 2];
vector<int> c2s[MAX_N + 2][MAX_N + 2];
// Scratch stack reused by init().
vector<int> st;
// Fills the nearest-greater arrays (leftGE/leftG, rightGE/rightG, upGE/upG,
// downGE/downG) for every cell of `mat`, writing INF into the border slots
// 0 and m+1 of each row and 0 and n+1 of each column as stack sentinels.
void init() {
    // One monotonic-stack pass over `count` positions first, first + step, ...
    // value(k) reads the cell at line position k; `sentinel` (an INF border
    // slot) is pushed first so the stack never runs dry. For each visited k,
    // store(k, j) receives the closest previously visited position j whose
    // value is >= value(k) (strictlyGreater == false) or > value(k)
    // (strictlyGreater == true), or the sentinel when there is none.
    auto sweep = [](auto value, auto store, int sentinel, int first, int count,
                    int step, bool strictlyGreater) {
        st.clear();
        st.push_back(sentinel);
        int k = first;
        for (int done = 0; done < count; done++, k += step) {
            while (strictlyGreater ? value(k) >= value(st.back())
                                   : value(k) > value(st.back()))
                st.pop_back();
            store(k, st.back());
            st.push_back(k);
        }
    };

    // Horizontal passes: positions are column indices within row l.
    for (int l = 1; l <= n; l++) {
        auto cell = [l](int c) { return mat[l][c]; };
        mat[l][0] = INF;
        sweep(cell, [l](int c, int j) { leftGE[l][c] = j; }, 0, 1, m, 1, false);
        sweep(cell, [l](int c, int j) { leftG[l][c] = j; }, 0, 1, m, 1, true);
        mat[l][m + 1] = INF;
        sweep(cell, [l](int c, int j) { rightGE[l][c] = j; }, m + 1, m, m, -1, false);
        sweep(cell, [l](int c, int j) { rightG[l][c] = j; }, m + 1, m, m, -1, true);
    }

    // Vertical passes: positions are row indices within column c.
    for (int c = 1; c <= m; c++) {
        auto cell = [c](int l) { return mat[l][c]; };
        mat[0][c] = INF;
        sweep(cell, [c](int l, int j) { upGE[l][c] = j; }, 0, 1, n, 1, false);
        sweep(cell, [c](int l, int j) { upG[l][c] = j; }, 0, 1, n, 1, true);
        mat[n + 1][c] = INF;
        sweep(cell, [c](int l, int j) { downGE[l][c] = j; }, n + 1, n, n, -1, false);
        sweep(cell, [c](int l, int j) { downG[l][c] = j; }, n + 1, n, n, -1, true);
    }
}
// Range-minimum query over rows l..r (requires l <= r) of the sparse table
// minRight[i], whose level 0 holds rightGE[row][i]; combines two overlapping
// power-of-two blocks.
int queryMinRight(int i, int l, int r) {
    // floor(log2(r - l + 1)) in integer arithmetic: no double round-trip and
    // no implicit double -> int narrowing on this hot path.
    const int p = 31 - __builtin_clz(static_cast<unsigned>(r - l + 1));
    return min(minRight[i][l][p], minRight[i][r - (1 << p) + 1][p]);
}
// Range-maximum query over rows l..r (requires l <= r) of the sparse table
// maxLeft[i], whose level 0 holds leftGE[row][i]; combines two overlapping
// power-of-two blocks with max. The result is the true maximum only if every
// level of maxLeft is built with max.
int queryMaxLeft(int i, int l, int r) {
    // floor(log2(r - l + 1)) in integer arithmetic: no double round-trip and
    // no implicit double -> int narrowing on this hot path.
    const int p = 31 - __builtin_clz(static_cast<unsigned>(r - l + 1));
    return max(maxLeft[i][l][p], maxLeft[i][r - (1 << p) + 1][p]);
}
// Range-minimum query over columns l..r (requires l <= r) of the sparse table
// minDown[i], whose level 0 holds downGE[i][column]; combines two overlapping
// power-of-two blocks.
int queryMinDown(int i, int l, int r) {
    // floor(log2(r - l + 1)) in integer arithmetic: no double round-trip and
    // no implicit double -> int narrowing on this hot path.
    const int p = 31 - __builtin_clz(static_cast<unsigned>(r - l + 1));
    return min(minDown[i][l][p], minDown[i][r - (1 << p) + 1][p]);
}
// Range-maximum query over columns l..r (requires l <= r) of the sparse table
// maxUp[i], whose level 0 holds upGE[i][column]; combines two overlapping
// power-of-two blocks with max.
int queryMaxUp(int i, int l, int r) {
    // floor(log2(r - l + 1)) in integer arithmetic: no double round-trip and
    // no implicit double -> int narrowing on this hot path.
    const int p = 31 - __builtin_clz(static_cast<unsigned>(r - l + 1));
    return max(maxUp[i][l][p], maxUp[i][r - (1 << p) + 1][p]);
}
// Counts (increments `ans` for) the rectangle with boundary rows l1 < l2 and
// boundary columns c1 < c2 when every interior cell (rows l1+1..l2-1, columns
// c1+1..c2-1) is strictly smaller than the boundary cells of its own row and
// column. Candidates outside the grid or with an empty interior are ignored.
// Every condition is answered in O(1) from the sparse tables built in
// count_rectangles().
void solve(int l1, int c1, int l2, int c2) {
    if (l1 < 1 || c1 < 1 || l2 > n || c2 > m || l2 - l1 <= 1 || c2 - c1 <= 1)
        return;
    // Row conditions: in every interior row l, the nearest cell right of c1
    // with value >= mat[l][c1] lies at column c2 or beyond, and the nearest
    // cell left of c2 with value >= mat[l][c2] lies at column c1 or before.
    if (queryMinRight(c1, l1 + 1, l2 - 1) < c2)
        return;
    if (queryMaxLeft(c2, l1 + 1, l2 - 1) > c1)
        return;
    // Column conditions: the same test along every interior column, using the
    // boundary rows l1 (looking down) and l2 (looking up).
    if (queryMinDown(l1, c1 + 1, c2 - 1) < l2)
        return;
    if (queryMaxUp(l2, c1 + 1, c2 - 1) > l1)
        return;
    ans++;
}

// Returns the number of rectangles in grid `a` whose interior cells are all
// strictly smaller than the boundary cells in their row and column.
// Assumes `a` is rectangular with at most MAX_N rows and columns (not checked).
long long count_rectangles(vector<vector<int>> a) {
    ans = 0;  // reset so repeated calls do not accumulate
    if (a.empty() || a[0].empty())
        return 0;  // no rectangles; also avoids indexing a[0] on empty input
    n = static_cast<int>(a.size());
    m = static_cast<int>(a[0].size());
    for (int l = 1; l <= n; l++)
        for (int c = 1; c <= m; c++)
            mat[l][c] = a[l - 1][c - 1];
    init();

    // Level 0 of the sparse tables. The row-wise tables (minRight, maxLeft)
    // are indexed [column][row] so a query walks consecutive rows of a column.
    for (int l = 1; l <= n; l++) {
        for (int c = 1; c <= m; c++) {
            minRight[c][l][0] = rightGE[l][c];
            maxLeft[c][l][0] = leftGE[l][c];
            minDown[l][c][0] = downGE[l][c];
            maxUp[l][c][0] = upGE[l][c];
        }
    }
    // Level p at index j covers j .. j + 2^p - 1; only entries that stay
    // inside the grid are built, so no read goes past row n / column m.
    for (int p = 1; (1 << p) <= n; p++) {
        const int half = 1 << (p - 1);
        for (int c = 1; c <= m; c++) {
            for (int l = 1; l + (1 << p) - 1 <= n; l++) {
                minRight[c][l][p] = min(minRight[c][l][p - 1], minRight[c][l + half][p - 1]);
                maxLeft[c][l][p] = max(maxLeft[c][l][p - 1], maxLeft[c][l + half][p - 1]);
            }
        }
    }
    for (int p = 1; (1 << p) <= m; p++) {
        const int half = 1 << (p - 1);
        for (int l = 1; l <= n; l++) {
            for (int c = 1; c + (1 << p) - 1 <= m; c++) {
                minDown[l][c][p] = min(minDown[l][c][p - 1], minDown[l][c + half][p - 1]);
                maxUp[l][c][p] = max(maxUp[l][c][p - 1], maxUp[l][c + half][p - 1]);
            }
        }
    }

    // Candidate boundary rows for top-left corner (l1, c1) come from column
    // c1 + 1: for a valid pair (l1, l2) in that column, either
    // upGE[l2] == l1 or downGE[l1] == l2. Candidate boundary columns come from
    // row l1 + 1 in the same way. Entries with index 0 belong to corners
    // outside the grid and are never read below.
    for (int l = 1; l <= n; l++) {
        for (int c = 1; c <= m; c++) {
            l2s[upGE[l][c]][c - 1].push_back(l);
            c2s[l - 1][leftGE[l][c]].push_back(c);
        }
    }
    for (int l = 1; l <= n; l++) {
        for (int c = 1; c <= m; c++) {
            vector<int> &rows = l2s[l][c];
            vector<int> &cols = c2s[l][c];
            // downGE / rightGE of the first interior column / row; skipped
            // when that column / row lies outside the grid.
            if (c < m)
                rows.push_back(downGE[l][c + 1]);
            if (l < n)
                cols.push_back(rightGE[l + 1][c]);
            sort(rows.begin(), rows.end());
            rows.erase(unique(rows.begin(), rows.end()), rows.end());
            sort(cols.begin(), cols.end());
            cols.erase(unique(cols.begin(), cols.end()), cols.end());
            for (int l2 : rows)
                for (int c2 : cols)
                    solve(l, c, l2, c2);
        }
    }
    return ans;
}
/*
 * Grader results table captured together with the source (not part of the program):
 *
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 */