Submission #417968

# | Time | Username | Problem | Language | Result | Execution time | Memory
417968 | | usachevd0 | Rectangles (IOI19_rect) | C++17 | 100 / 100 | 3345 ms | 953296 KiB
// #pragma gcc optimize("Ofast")
// #pragma gcc optimize("O3")
// #pragma gcc optimize("fast-math")
// #pragma gcc optimize("no-stack-protector")
// #pragma gcc optimize("unroll-loops")
// #pragma gcc target("avx,avx2,fma,sse,sse2,sse3,sse4,sse5,popcnt")
#include <bits/stdc++.h>
#ifndef LOCAL
    #include "rect.h"
#endif

using namespace std;

#define fi first
#define se second
#define mp make_pair
#define pb push_back
#define all(a) (a).begin(), (a).end()
#define Time (clock() * 1.0 / CLOCKS_PER_SEC)
using ll = long long;
using ull = unsigned long long;
using pii = pair<int, int>;
using pil = pair<int, ll>;
using pli = pair<ll, int>;
using pll = pair<ll, ll>;
using ld = long double;
// Lowers x to y when y is strictly smaller.
// Returns true if x was changed, false otherwise.
template<typename T1, typename T2> bool chkmin(T1& x, T2 y) {
    if (y < x) {
        x = y;
        return true;
    }
    return false;
}
// Raises x to y when y is strictly larger.
// Returns true if x was changed, false otherwise.
template<typename T1, typename T2> bool chkmax(T1& x, T2 y) {
    if (y > x) {
        x = y;
        return true;
    }
    return false;
}
// Ends a debug line on stderr (newline plus flush).
void debug_out() {
    std::cerr << std::endl;
}
// Writes every argument to stderr, each preceded by a single space,
// then ends the line.
template<typename T1, typename... T2> void debug_out(T1 A, T2... B) {
    std::cerr << ' ' << A;
    // Fold over the remaining arguments, left to right.
    ((std::cerr << ' ' << B), ...);
    debug_out();
}
// Writes the first n elements of array a to stderr, each followed by a
// space, then ends the line. Prints only the line end when n <= 0.
template<typename T> void mdebug_out(T* a, int n) {
    T* cursor = a;
    for (int remaining = n; remaining > 0; --remaining) {
        std::cerr << *cursor << ' ';
        ++cursor;
    }
    std::cerr << std::endl;
}
#ifdef LOCAL
    #define debug(...) cerr << "[" << #__VA_ARGS__ << "]:", debug_out(__VA_ARGS__)
    #define mdebug(a, n) cerr << #a << ": ", mdebug_out(a, n)
#else
    #define debug(...) 1337
    #define mdebug(a, n) 1337
#endif
// Streams every element of v followed by a single space (so a non-empty
// vector leaves a trailing space). Returns the stream for chaining.
template<typename T> std::ostream& operator << (std::ostream& stream, const std::vector<T>& v) {
    for (auto it = v.begin(); it != v.end(); ++it) {
        stream << *it;
        stream << ' ';
    }
    return stream;
}
// Streams a pair as "first second" (single space between, nothing after).
// Returns the stream for chaining.
template<typename T1, typename T2> std::ostream& operator << (std::ostream& stream, const std::pair<T1, T2>& p) {
    stream << p.first;
    stream << ' ';
    stream << p.second;
    return stream;
}

// Large 32-bit sentinel. In the visible code it is only referenced by the
// commented-out reference solution further down.
const int INF32 = 1e9;
// Side length of every fixed-size table below. The grid is stored 1-based
// and cells at index 0 and n + 1 / m + 1 are read as borders, so a grid of
// n x m needs n + 2 <= maxN and m + 2 <= maxN.
#ifdef LOCAL
    const int maxN = 30;
#else
    const int maxN = 2503;
#endif
// Grid dimensions (rows, columns) of the grid currently being processed.
int n, m;
// 1-based copy of the input grid; count_rectangles fills the whole array
// with 0x3f bytes before copying, so untouched border cells hold a large value.
int a[maxN][maxN];
// For cell (x, y), filled by precUDLR:
//   U / D - row index of the nearest cell above / below in column y whose
//           value is not smaller than a[x][y];
//   L / R - column index of the nearest such cell to the left / right in row x.
// The border indices 0, n + 1 and m + 1 act as the fallback.
int U[maxN][maxN], D[maxN][maxN], L[maxN][maxN], R[maxN][maxN];
// goodH[r1][r2]: columns y in which rows r1..r2 are bounded by cells
//                (r1 - 1, y) and (r2 + 1, y) as found through U / D.
// goodV[c1][c2]: rows x in which columns c1..c2 are bounded by cells
//                (x, c1 - 1) and (x, c2 + 1) as found through L / R.
// Both are appended in increasing order and deduplicated before use.
vector<int> goodH[maxN][maxN], goodV[maxN][maxN];
// goodVD[c1][c2][i]: last row of the run of consecutive rows in
// goodV[c1][c2] that starts at goodV[c1][c2][i].
vector<int> goodVD[maxN][maxN];
// ptrVD[c1][c2]: scan position inside goodV[c1][c2]; it only ever moves
// forward while the top row r1 grows in count_rectangles.
int ptrVD[maxN][maxN];

// Scratch state for one (r1, r2) row range inside count_rectangles.
// good[c]: column c is listed in goodH[r1][r2] (cleared again afterwards).
bool good[maxN];
// bh[0..bhc): increasing list of the good columns and their direct neighbours.
int bh[maxN], bhc;
// bhl[i] / bhr[i]: lower bound for a left partner / upper bound for a right
// partner of column bh[i], derived from the nearest non-good column.
int bhl[maxN], bhr[maxN];

void precUDLR() {
    static int stk[maxN], sz;
    for (int y = 1; y <= m; ++y) {
        stk[0] = 0;
        sz = 1;
        for (int x = 1; x <= n; ++x) {
            while (a[stk[sz - 1]][y] < a[x][y])
                --sz;
            U[x][y] = stk[sz - 1];
            stk[sz++] = x;
        }
        stk[0] = n + 1;
        sz = 1;
        for (int x = n; x >= 1; --x) {
            while (a[stk[sz - 1]][y] < a[x][y])
                --sz;
            D[x][y] = stk[sz - 1];
            stk[sz++] = x;
        }
    }
    for (int x = 1; x <= n; ++x) {
        stk[0] = 0;
        sz = 1;
        for (int y = 1; y <= m; ++y) {
            while (a[x][stk[sz - 1]] < a[x][y])
                --sz;
            L[x][y] = stk[sz - 1];
            stk[sz++] = y;
        }
        stk[0] = m + 1;
        sz = 1;
        for (int y = m; y >= 1; --y) {
            while (a[x][stk[sz - 1]] < a[x][y])
                --sz;
            R[x][y] = stk[sz - 1];
            stk[sz++] = y;
        }
    }
}

ll count_rectangles(vector<vector<int>> _a) {
    n = _a.size(), m = _a[0].size();
    memset(a, 0x3f, sizeof a);
    for (int i = 0; i < n; ++i)
        for (int j = 0; j < m; ++j)
            a[i + 1][j + 1] = _a[i][j];
    precUDLR();
    for (int y = 1; y <= m; ++y)
        for (int x = 1; x <= n; ++x) {
            if (x + 2 <= D[x][y] && D[x][y] <= n) {
                goodH[x + 1][D[x][y] - 1].push_back(y);
                // debug("H", x, D[x][y], y);
            }
            if (1 <= U[x][y] && U[x][y] <= x - 2) {
                goodH[U[x][y] + 1][x - 1].push_back(y);
                // debug("H", U[x][y], x, y);
            }
        }
    for (int x = 1; x <= n; ++x) {
        for (int y = 1; y <= m; ++y) {
            if (y + 2 <= R[x][y] && R[x][y] <= m) {
                goodV[y + 1][R[x][y] - 1].push_back(x);
                // debug("V", y, R[x][y], x);
            }
            if (1 <= L[x][y] && L[x][y] <= y - 2) {
                goodV[L[x][y] + 1][y - 1].push_back(x);
                // debug("V", L[x][y], y, x);
            }
        }
    }
    for (int y1 = 2; y1 <= m - 1; ++y1)
        for (int y2 = y1; y2 <= m - 1; ++y2) {
            auto& g = goodV[y1][y2];
            if (g.empty()) continue;
            g.resize(unique(all(g)) - g.begin());
            auto& d = goodVD[y1][y2];
            int k = g.size();
            d.resize(k);
            d[k - 1] = g[k - 1];
            for (int i = k - 2; i >= 0; --i) {
                if (g[i + 1] == g[i] + 1)
                    d[i] = d[i + 1];
                else
                    d[i] = g[i];
            }
        }
    ll ans = 0;
    for (int r1 = 2; r1 <= n - 1; ++r1) {
        int* RR = R[r1];
        int* LL = L[r1];
        for (int r2 = r1; r2 <= n - 1; ++r2) {
            auto& gh = goodH[r1][r2];
            if (gh.empty()) continue;
            gh.resize(unique(all(gh)) - gh.begin());
            for (int c : gh)
                good[c] = true;
            bhc = 0;           
            for (int c : gh) {
                if (c >= 2 && (!bhc || bh[bhc - 1] < c - 1))
                    bh[bhc++] = c - 1;
                if (!bhc || bh[bhc - 1] < c)
                    bh[bhc++] = c;
                if (c + 1 <= m && (!bhc || bh[bhc - 1] < c + 1))
                    bh[bhc++] = c + 1;
            }
            // debug(r1, r2);
            // debug(gh);
            // mdebug(good+1, m);
            // mdebug(bh, bhc);
            bhl[0] = bh[0];
            for (int i = 1; i < bhc; ++i) {
                if (!good[bh[i] - 1])
                    bhl[i] = bh[i] - 1;
                else
                    bhl[i] = bhl[i - 1];
            }
            bhr[bhc - 1] = bh[bhc - 1];
            for (int i = bhc - 2; i >= 0; --i) {
                if (!good[bh[i] + 1])
                    bhr[i] = bh[i] + 1;
                else
                    bhr[i] = bhr[i + 1];
            }
            // mdebug(bhl, bhc);
            // mdebug(bhr, bhc);
            for (int i = 0; i < bhc; ++i) {
                int c = bh[i];
                int l = LL[c];
                if (bhl[i] <= l && l <= c - 2 && RR[l] != c) {
                    // debug(l, c);
                    auto& gv = goodV[l + 1][c - 1];
                    auto& gvd = goodVD[l + 1][c - 1];
                    auto& ptr = ptrVD[l + 1][c - 1];
                    while (gv[ptr] != r1)
                        ++ptr;
                    ans += gvd[ptr] >= r2;
                    // if (gvd[ptr] >= r2) {
                    //     debug("l", r1, r2, l + 1, c - 1);
                    // }
                }
            }
            for (int i = 0; i < bhc; ++i) {
                int c = bh[i];
                int r = RR[c];
                if (c + 2 <= r && r <= bhr[i]) {
                    // debug(c, r);
                    auto& gv = goodV[c + 1][r - 1];
                    auto& gvd = goodVD[c + 1][r - 1];
                    auto& ptr = ptrVD[c + 1][r - 1];
                    while (gv[ptr] != r1)
                        ++ptr;
                    ans += gvd[ptr] >= r2;
                    // if (gvd[ptr] >= r2) {
                    //     debug("r", r1, r2, c + 1, r - 1);
                    // }
                }
            }
            for (int c : gh)
                good[c] = false;
        }
    }
    return ans;
}

// namespace stu {
//     const int INF32 = 1e9;
// #ifdef DEBUG
//     const int maxN = 30;
// #else
//     const int maxN = 7e2 + 5;
// #endif
// bool badc[maxN];
// int a[maxN][maxN];
// int L[maxN], R[maxN], B[maxN];
// int stk[maxN], sz;
// int mxc[maxN];

// ll count_rectangles(vector<vector<int>> _a) {
//     int n = _a.size(), m = _a[0].size();
//     for (int i = 0; i < n; ++i)
//         for (int j = 0; j < m; ++j)
//             a[i][j] = _a[i][j];
//     if (min(n, m) < 3) return 0;
//     ll ans = 0;
//     for (int r1 = 1; r1 < n - 1; ++r1) {
//         for (int c = 0; c < m; ++c) {
//             mxc[c] = -INF32;
//             L[c] = 0;
//             R[c] = m;
//         }
//         for (int r2 = r1; r2 < n - 1; ++r2) {
//             int *curRow = a[r2];
//             badc[0] = badc[m - 1] = 1;
//             for (int c = 1; c < m - 1; ++c) {
//                 chkmax(mxc[c], curRow[c]);
//                 badc[c] = mxc[c] >= min(a[r1 - 1][c], a[r2 + 1][c]);
//             }
//             // count L
//             sz = 0;
//             for (int c = 0; c < m; ++c) {
//                 while (sz && curRow[stk[sz - 1]] < curRow[c])
//                     --sz;
//                 if (sz)
//                     chkmax(L[c], stk[sz - 1]);
//                 stk[sz] = c;
//                 ++sz;
//             }
//             // count R
//             sz = 0;
//             for (int c = m - 1; c >= 0; --c) {
//                 while (sz && curRow[stk[sz - 1]] < curRow[c])
//                     --sz;
//                 if (sz)
//                     chkmin(R[c], stk[sz - 1]);
//                 stk[sz] = c;
//                 ++sz;
//             }
            
//             int b = m - 1;
//             for (int c = m - 1; c >= 0; --c) {
//                 B[c] = b;
//                 if (badc[c])
//                     b = c;
//             }
//             for (int c = 0; c < m; ++c) {
//                 int l = L[c], r = R[c];
//                 if (r < m && r <= B[c] && r >= c + 2 && L[r] < c)
//                     ++ans;
//                 if (c <= B[l] && c <= R[l] && c >= l + 2)
//                     ++ans;
//             }
//         }
//     }
//     return ans;
// }
// }

#ifdef LOCAL

// Fixed-seed generator so local stress runs are reproducible.
mt19937 rng(228);
// Returns a uniformly distributed integer in the closed range [l, r].
// Requires l <= r. Using the standard distribution avoids the modulo bias
// and the overflow / division by zero of `rng() % (r - l + 1)` on wide ranges.
int randint(int l, int r) {
    return uniform_int_distribution<int>(l, r)(rng);
}

int main() {
#ifdef LOCAL
    freopen("in", "r", stdin);
#endif
    
    // for (int test = 1; ; ++test) {
    //     int n = 4, m = 4;
    //     const int C = 3;
    //     vector<vector<int>> a(n, vector<int>(m));
    //     for (auto& i : a)
    //         for (auto& j : i)
    //             j = randint(0, C);
    //     if (count_rectangles(a) != stu::count_rectangles(a)) {
    //         cout << n << ' ' << m << endl;
    //         for (auto& i : a)
    //             cout << i << endl;
    //         exit(0);
    //     }
    //     if (test % 10000 == 0) debug(test);
    // }
    // exit(0);
    
    // InputReader inputReader(STDIN_FILENO);
    int n, m;
    cin >> n >> m;
    // n = inputReader.readInt();
    // m = inputReader.readInt();
    vector<vector<int>> a(n, vector<int>(m));
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < m; j++) {
            cin >> a[i][j];
            // a[i][j] = inputReader.readInt();
        }
    }
    // inputReader.close();
    
    auto oldTime = Time;
    long long result = count_rectangles(a);
    // debug(Time - oldTime);

    printf("%lld\n", result);
    fclose(stdout);
    return 0;
}
#endif
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...