제출 #1155637

# | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리
1155637 | antonn | Sandcastle 2 (JOI22_ho_t5) | C++20
80 / 100
5094 ms | 36128 KiB
#include <bits/stdc++.h>

using namespace std;

typedef long long ll;

// Offsets of the four grid neighbours. Index d selects one neighbour:
// main() adds dx[d] to the row index and dy[d] to the column index, so the
// entries are, in order: (0,-1), (0,+1), (-1,0), (+1,0).
int dx[4] = {0, 0, -1, +1};
int dy[4] = {-1, +1, 0, 0};

/**
 * Reads an h-by-w integer grid from standard input and prints the number of
 * axis-aligned sub-rectangles whose per-cell contributions sum to h*w + 1.
 *
 * Input : h w, followed by h rows of w integers.
 * Output: one line with the count.
 *
 * Steps:
 *   1. Transpose the grid when h > w, so that h <= w afterwards.
 *   2. Replace every value by its 1-based rank among all cells.
 *   3. For every cell and every subset ("mask") of its four neighbours,
 *      precompute a contribution:
 *          rank - (largest smaller rank in the subset, or 0)
 *          + (h*w + 1 - rank)   if the subset holds no larger rank.
 *      Bit d of the mask selects the neighbour at (i + dx[d], j + dy[d]).
 *   4. Build a 2D prefix sum of the contributions for each mask.
 *   5. Enumerate every rectangle; each cell contributes with the mask of its
 *      neighbours that lie inside the rectangle (corners, edges and interior
 *      use different masks). Count rectangles whose total equals h*w + 1.
 *
 * NOTE(review): presumably a total of exactly h*w + 1 means the cells of the
 * rectangle form a single chain when ordered by value (JOI "Sandcastle 2");
 * confirm against the problem statement.
 *
 * The enumeration in step 5 visits O(h^2 * w^2) rectangles with O(1) work
 * per rectangle.
 */
int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    using Grid = std::vector<std::vector<int>>;

    int h, w;
    std::cin >> h >> w;

    // The grid is 1-indexed; row 0 and column 0 are unused padding.
    Grid a(h + 1, std::vector<int>(w + 1));
    for (int i = 1; i <= h; ++i) {
        for (int j = 1; j <= w; ++j) {
            std::cin >> a[i][j];
        }
    }

    // Transpose so that the row count never exceeds the column count.
    if (h > w) {
        Grid transposed(w + 1, std::vector<int>(h + 1));
        for (int i = 1; i <= h; ++i) {
            for (int j = 1; j <= w; ++j) {
                transposed[j][i] = a[i][j];
            }
        }
        std::swap(a, transposed);
        std::swap(h, w);
    }

    // Rank compression: each value becomes (index of its last occurrence in
    // the sorted list) + 1, which is what upper_bound's offset yields.
    std::vector<int> vals;
    vals.reserve(static_cast<std::size_t>(h) * static_cast<std::size_t>(w));
    for (int i = 1; i <= h; ++i) {
        for (int j = 1; j <= w; ++j) {
            vals.push_back(a[i][j]);
        }
    }
    std::sort(vals.begin(), vals.end());
    for (int i = 1; i <= h; ++i) {
        for (int j = 1; j <= w; ++j) {
            a[i][j] = static_cast<int>(
                std::upper_bound(vals.begin(), vals.end(), a[i][j]) - vals.begin());
        }
    }

    // Sentinel one above the largest rank; also the target total.
    const int top = h * w + 1;

    // val[mask][i][j]: contribution of cell (i, j) when exactly the
    // neighbours selected by mask are taken into account. Entries whose mask
    // selects a neighbour outside the grid stay 0 and are never queried.
    std::vector<Grid> val(16, Grid(h + 1, std::vector<int>(w + 1)));
    for (int i = 1; i <= h; ++i) {
        for (int j = 1; j <= w; ++j) {
            int neigh[4];
            for (int d = 0; d < 4; ++d) {
                const int ni = i + dx[d];
                const int nj = j + dy[d];
                // Both coordinates of the neighbour must be inside the grid.
                // (The column upper bound must test nj, not the bare offset,
                // otherwise a[i][w + 1] is read out of range.)
                const bool inside = ni >= 1 && ni <= h && nj >= 1 && nj <= w;
                neigh[d] = inside ? a[ni][nj] : -1;
            }

            for (int mask = 0; mask < 16; ++mask) {
                bool usable = true;
                int lo = 0;    // largest selected rank below a[i][j], 0 if none
                int hi = top;  // smallest selected rank above a[i][j], top if none
                for (int b = 0; b < 4; ++b) {
                    if ((mask & (1 << b)) == 0) {
                        continue;
                    }
                    const int x = neigh[b];
                    if (x == -1) {
                        usable = false;  // mask refers to a cell off the grid
                        break;
                    }
                    if (x < a[i][j]) lo = std::max(lo, x);
                    if (x > a[i][j]) hi = std::min(hi, x);
                }
                if (!usable) {
                    continue;
                }
                val[mask][i][j] = a[i][j] - lo + (hi == top ? top - a[i][j] : 0);
            }
        }
    }

    // sum[mask][i][j]: sum of val[mask] over rows 1..i and columns 1..j.
    std::vector<std::vector<std::vector<long long>>> sum(
        16, std::vector<std::vector<long long>>(h + 1, std::vector<long long>(w + 1)));
    for (int mask = 0; mask < 16; ++mask) {
        for (int i = 1; i <= h; ++i) {
            for (int j = 1; j <= w; ++j) {
                sum[mask][i][j] = sum[mask][i - 1][j] + sum[mask][i][j - 1]
                                - sum[mask][i - 1][j - 1] + val[mask][i][j];
            }
        }
    }

    // Sum of val[mask] over rows x1..x2 and columns y1..y2; 0 for an empty range.
    auto get = [&](int mask, int x1, int y1, int x2, int y2) -> long long {
        if (x1 > x2 || y1 > y2) return 0;
        return sum[mask][x2][y2] - sum[mask][x1 - 1][y2]
             - sum[mask][x2][y1 - 1] + sum[mask][x1 - 1][y1 - 1];
    };

    // Total contribution of the rectangle with corners (x1, y1) and (x2, y2),
    // x1 <= x2 and y1 <= y2. Mask bits: 1 = column j-1, 2 = column j+1,
    // 4 = row i-1, 8 = row i+1.
    auto get_sum = [&](int x1, int y1, int x2, int y2) -> long long {
        if (x1 == x2 && y1 == y2) {
            return val[0][x1][y1];  // single cell: no neighbours inside
        }
        if (x1 == x2) {
            // Single row: the end cells see one horizontal neighbour, the
            // cells in between see both.
            return val[2][x1][y1] + val[1][x2][y2]
                 + get(3, x1, y1 + 1, x2, y2 - 1);
        }
        if (y1 == y2) {
            // Single column: same idea with the vertical neighbours.
            return val[8][x1][y1] + val[4][x2][y2]
                 + get(12, x1 + 1, y1, x2 - 1, y2);
        }
        long long total = 0;
        // Four corners.
        total += val[10][x1][y1];
        total += val[9][x1][y2];
        total += val[6][x2][y1];
        total += val[5][x2][y2];
        // Four edges without their corners (get() returns 0 when empty).
        total += get(11, x1, y1 + 1, x1, y2 - 1);
        total += get(7, x2, y1 + 1, x2, y2 - 1);
        total += get(14, x1 + 1, y1, x2 - 1, y1);
        total += get(13, x1 + 1, y2, x2 - 1, y2);
        // Interior cells see all four neighbours.
        total += get(15, x1 + 1, y1 + 1, x2 - 1, y2 - 1);
        return total;
    };

    // Enumerate every rectangle by its top-left (i, j) and bottom-right (x, y).
    long long count = 0;
    for (int i = 1; i <= h; ++i) {
        for (int j = 1; j <= w; ++j) {
            for (int x = i; x <= h; ++x) {
                for (int y = j; y <= w; ++y) {
                    if (get_sum(i, j, x, y) == top) {
                        ++count;
                    }
                }
            }
        }
    }
    std::cout << count << "\n";
    return 0;
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...