// This submission was judged on a previous version of oj.uz. Judging now runs on a different server than the one used at submission time, so resubmitting may produce different results.
#include "rect.h"
#include <bits/stdc++.h>
#define fi first
#define se second
#define mp make_pair
#define TASK ""
#define bit(x) (1LL << (x))
#define getbit(x, i) (((x) >> (i)) & 1)
#define ALL(x) (x).begin(), (x).end()
using namespace std;
/// Lowers `a` to `b` when `b` is strictly smaller.
/// @return true iff `a` was modified.
template <typename T1, typename T2> bool mini(T1 &a, T2 b) {
    const bool smaller = a > b;
    if (smaller)
        a = b;
    return smaller;
}
/// Raises `a` to `b` when `b` is strictly larger.
/// @return true iff `a` was modified.
template <typename T1, typename T2> bool maxi(T1 &a, T2 b) {
    const bool larger = a < b;
    if (larger)
        a = b;
    return larger;
}
// 64-bit Mersenne Twister seeded from the steady clock (not reproducible run to run).
mt19937_64 rd(chrono::steady_clock::now().time_since_epoch().count());
/// Returns a pseudo-random integer in [l, r]; intended for l <= r
/// (r - l + 1 == 0 would divide by zero).
int rand(int l, int r) {
    const int span = r - l + 1;
    return l + rd() % span;
}
// Static array bound. Indices up to n + 1 (events[v + 1]) and m + 1
// (col[i - 1][l + 1]) are used, so max(n, m) must be at most N - 2.
const int N = 2505;
// Sentinel larger than any grid value: findmax() stores it in a[0], and
// solve() uses it as the upper search key passed to upper_bound.
const int oo = 1e9;
// ooo, mod and pi are not referenced anywhere in the visible code.
const long long ooo = 1e18;
const int mod = 1e9 + 7; // 998244353;
const long double pi = acos(-1);
// sr[u][v]: runs (first_row, last_row) of consecutive rows in which column
// pair (u, v) was recorded by addrow().
// sc[u][v]: runs (first_col, last_col) of consecutive columns in which row
// pair (u, v) was recorded by addcol().
// NOTE(review): sr, sc and col are each N * N std::vector objects, i.e.
// roughly 150 MB apiece assuming a 24-byte std::vector — confirm against
// the judge's memory limit.
vector <pair <int, int>> sr[N][N], sc[N][N];
// events[i]: (l, r, last_row) when a run of sr[l][r] starts at row i, and the
// negated tuple (-l, -r, -last_row) when a run ended at row i - 1.
vector <tuple <int, int, int>> events[N];
// col[u][j]: every v such that row pair (u, v) was recorded at column j.
vector <int> col[N][N];
// row[l]: scratch list in solve() of right boundaries r of active pairs (l, r).
vector <int> row[N];
// pos[l][r]: index of pair (l, r) inside solve()'s candidate list.
int pos[N][N];
// limc is not referenced anywhere in the visible code.
int limc[N][N];
// lim[l][r]: last row of the sr[l][r] run currently active in solve().
int lim[N][N];
// The grid, stored 1-based; a[i][0] is overwritten with oo by findmax().
int a[N][N];
// Fenwick tree over columns 1..m (see update() / getsum()). The name coexists
// with the function-like macro bit(x) because `bit[` is not a macro call.
int bit[N];
// Scratch buffers for build(): v holds one grid column, l / r receive the
// nearest greater-or-equal indices produced by findmax().
int v[N];
int l[N];
int r[N];
// Grid dimensions (rows, columns), set by count_rectangles().
int n,m;
/// Fenwick-tree point update: adds `val` at 1-based column `pos`, walking
/// up to index m. `pos` must be positive (0 would never advance).
void update(int pos, int val) {
    while (pos <= m) {
        bit[pos] += val;
        pos += pos & -pos;
    }
}
/// Fenwick-tree prefix sum of columns 1..pos (returns 0 for pos == 0).
int getsum(int pos) {
    int total = 0;
    while (pos != 0) {
        total += bit[pos];
        pos -= pos & -pos;
    }
    return total;
}
/**
 * For every i in 1..n, stores in l[i] the nearest index j < i with
 * a[j] >= a[i], or 0 when no such index exists.
 *
 * Side effect: overwrites a[0] with the sentinel `oo`, which keeps the stack
 * from ever emptying as long as every a[i] <= oo.
 * @param a values at a[1..n]; a[0] is clobbered.
 * @param l output, filled at l[1..n].
 * @param n number of values.
 */
void findmax(int a[], int l[], int n) {
    a[0] = oo;
    // Indices whose values are non-increasing from bottom to top; the
    // sentinel index 0 sits at the bottom.
    vector <int> stk(1, 0);
    for (int i = 1; i <= n; ++i) {
        while (a[stk.back()] < a[i])
            stk.pop_back();
        l[i] = stk.back();
        stk.push_back(i);
    }
}
/**
 * Records column pair (u, v) as found in row i. If the last stored run in
 * sr[u][v] ends at row i - 1, that run is extended to row i; otherwise a new
 * (i, i) run is appended. build() calls this with i in increasing order.
 */
void addrow(int i, int u, int v) {
    vector <pair <int, int>> &runs = sr[u][v];
    const bool extends = !runs.empty() && runs.back().second == i - 1;
    if (extends)
        runs.back().second = i;
    else
        runs.emplace_back(i, i);
}
/**
 * Records row pair (u, v) as found in column j. The pair's column runs in
 * sc[u][v] are merged like addrow() does for rows, and v is appended to
 * col[u][j] so solve() can enumerate bottom rows for a given top row/column.
 */
void addcol(int j, int u, int v) {
    vector <pair <int, int>> &runs = sc[u][v];
    const bool extends = !runs.empty() && runs.back().second == j - 1;
    if (extends)
        runs.back().second = j;
    else
        runs.emplace_back(j, j);
    col[u][j].push_back(v);
}
// Precomputes every boundary pair of the grid:
//  * per row i, column pairs (u, v) with v - u >= 2 whose cells strictly
//    between are all smaller than both endpoints -> addrow(i, u, v);
//  * per column j, the analogous row pairs -> addcol(j, u, v);
// then turns each run in sr[l][r] into a start event and a removal event.
void build() {
for (int i = 1; i <= n; i++) {
// l[j]: nearest column left of j with a value >= a[i][j] (0 if none).
findmax(a[i], l, m);
// Same scan on the reversed row gives the nearest >= to the right, in
// reversed coordinates; the row is restored right after.
reverse(a[i] + 1, a[i] + m + 1);
findmax(a[i], r, m);
reverse(a[i] + 1, a[i] + m + 1);
reverse(r + 1, r + m + 1);
// Map reversed indices back to original columns (0 becomes m + 1 = "none").
for (int j = 1; j <= m; j++)
r[j] = (m - r[j] + 1);
for (int j = 1; j <= m; j++) {
// Left partner exists and is not adjacent (at least one cell between).
if (l[j] > 0 && l[j] != j - 1) {
auto [u, v] = mp(l[j], j);
addrow(i, u, v);
}
// Right partner: the strict '<' skips equal heights, which the left-side
// test above already records (from the partner's side), so no duplicates.
if (r[j] <= m && r[j] != j + 1 && a[i][j] < a[i][r[j]]) {
auto [u, v] = mp(j, r[j]);
addrow(i, u, v);
}
}
}
// Same procedure per column, copying the column into the scratch array v.
for (int j = 1; j <= m; j++) {
for (int i = 1; i <= n; i++)
v[i] = a[i][j];
findmax(v, l, n);
// v is left reversed afterwards; the comparisons below read a directly.
reverse(v + 1, v + n + 1);
findmax(v, r, n);
reverse(r + 1, r + n + 1);
for (int i = 1; i <= n; i++)
r[i] = (n - r[i] + 1);
for (int i = 1; i <= n; i++) {
if (l[i] > 0 && l[i] != i - 1) {
auto [u, v] = mp(l[i], i);
addcol(j, u, v);
}
if (r[i] <= n && r[i] != i + 1 && a[i][j] < a[r[i]][j]) {
auto [u, v] = mp(i, r[i]);
addcol(j, u, v);
}
}
}
// Emit events per run of rows [u, v] of pair (l, r): insert at row u,
// remove (negated tuple) at row v + 1. The loop variables l / r shadow the
// global scratch arrays of the same names.
for (int l = 1; l <= m; l++)
for (int r = l + 1; r <= m; r++)
for (auto [u, v] : sr[l][r]) {
// Naive per-row version, replaced by the start/end events below:
// for (int j = u; j <= v; j++)
// events[j].push_back({l, r});
events[u].push_back({l, r, v});
events[v + 1].push_back({-l, -r, -v});
}
}
/**
 * Counts valid rectangles by sweeping the top interior row `i`.
 *
 * For each row i, `cand` holds every column pair (l, r) recorded at row i,
 * and lim[l][r] is the last row of its current run. For a fixed left column
 * l, the pairs are visited in increasing lim order while the bottom rows k
 * of row pairs (i - 1, k) recorded at column l + 1 are fed, in increasing k,
 * into a Fenwick tree keyed by the last column of the sc[i - 1][k] run that
 * contains column l + 1. A pair (l, r) then contributes every fed k whose
 * run reaches column r - 1.
 *
 * @return number of valid rectangles, accumulated in long long to match the
 *         return type of count_rectangles().
 */
long long solve() {
    build();
    long long res = 0;
    vector <pair <int, int>> cand;
    for (int i = 1; i < n; i++) {
        // Apply row i's events: insert pairs whose run starts here and
        // swap-remove pairs whose run ended at row i - 1.
        for (auto [l, r, v] : events[i]) {
            if (l > 0) {
                pos[l][r] = cand.size();
                lim[l][r] = v;
                cand.push_back({l, r});
            } else {
                l = -l, r = -r;
                pos[cand.back().fi][cand.back().se] = pos[l][r];
                swap(cand[pos[l][r]], cand.back());
                cand.pop_back();
            }
        }
        // Row i - 1 is the top boundary row; there is no row 0.
        if (i == 1)
            continue;
        for (int j = 1; j <= m; j++)
            row[j].clear();
        for (auto [l, r] : cand)
            row[l].push_back(r);
        for (int l = 1; l <= m; l++) {
            sort(ALL(row[l]), [&](int x, int y) {
                return lim[l][x] < lim[l][y];
            });
            sort(ALL(col[i - 1][l + 1]));
            vector <int> added;  // Fenwick positions inserted for this l
            const vector <int> &f = row[l];
            const vector <int> &g = col[i - 1][l + 1];
            for (int it = 0, jt = 0; it < (int) f.size(); it++) {
                int r = f[it];
                // Feed every bottom row k with k <= lim[l][r] + 1.
                while (jt < (int) g.size() && g[jt] <= lim[l][r] + 1) {
                    int k = g[jt];
                    pair <int, int> to_find = {l + 1, oo};
                    // Last run of sc[i - 1][k] starting at or before column l + 1,
                    // i.e. the run containing it (k is in col[i - 1][l + 1]).
                    int last_col = prev(upper_bound(ALL(sc[i - 1][k]), to_find))->se;
                    update(last_col, 1);
                    added.push_back(last_col);
                    jt++;
                }
                // Fed entries whose column run reaches at least column r - 1.
                res += getsum(m) - getsum(r - 2);
            }
            // Undo this l's insertions so the tree is empty for the next l.
            for (int c : added)
                update(c, -1);
        }
    }
    return res;
}
/**
 * Grader entry point: counts valid rectangles in the grid `_a`.
 *
 * Copies `_a` into the 1-based global grid `a`, sets n / m and delegates to
 * solve().
 * @param _a grid values, `_a[i][j]` for 0-based row i and column j; every row
 *           is assumed to have `_a[0].size()` entries.
 * @return the number of valid rectangles. Grids with fewer than 3 rows or
 *         columns (including an empty grid) yield 0, since a rectangle needs
 *         a boundary row/column on every side.
 */
long long count_rectangles(vector<vector<int>> _a) {
    // Checked before touching _a[0], which is out of range for an empty grid.
    if (_a.size() < 3 || _a[0].size() < 3)
        return 0;
    n = _a.size();
    m = _a[0].size();
    for (int i = 1; i <= n; i++)
        for (int j = 1; j <= m; j++)
            a[i][j] = _a[i - 1][j - 1];
    return solve();
}
//#include "grader.cpp"
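/*
Judging results captured from the oj.uz submission page (seven result tables;
the results had not finished loading when the page was saved):

| # | Verdict | Execution time | Memory | Grader output |
|---|---------|----------------|--------|---------------|
| Fetching results... | | | | |

| # | Verdict | Execution time | Memory | Grader output |
|---|---------|----------------|--------|---------------|
| Fetching results... | | | | |

| # | Verdict | Execution time | Memory | Grader output |
|---|---------|----------------|--------|---------------|
| Fetching results... | | | | |

| # | Verdict | Execution time | Memory | Grader output |
|---|---------|----------------|--------|---------------|
| Fetching results... | | | | |

| # | Verdict | Execution time | Memory | Grader output |
|---|---------|----------------|--------|---------------|
| Fetching results... | | | | |

| # | Verdict | Execution time | Memory | Grader output |
|---|---------|----------------|--------|---------------|
| Fetching results... | | | | |

| # | Verdict | Execution time | Memory | Grader output |
|---|---------|----------------|--------|---------------|
| Fetching results... | | | | |
*/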