답안 #765359

# 제출 시각 아이디 문제 언어 결과 실행 시간 메모리
765359 2023-06-24T11:31:59 Z abysmal Sandcastle 2 (JOI22_ho_t5) C++14
0 / 100
15 ms 784 KB
#include<iostream>
#include<stdio.h>
#include<stdint.h>
#include<iomanip>
#include<algorithm>
#include<utility>
#include<vector>
#include<stack>
#include<queue>
#include<set>
#include<map>
#include<deque>
#include<math.h>
#include<assert.h>

using namespace std;

const int64_t INF = (int64_t) 2 * 1e9 + 100;
const int64_t mINF = (int64_t) 1e9 + 5;
const int64_t MOD = 1e9 + 7;
const int nbit = 31;
const int nchar = 26;
const int D = 4;
int dr[D] = {-1, 0, 1, 0};
int dc[D] = {0, 1, 0, -1};

struct Cell
{
    int r;
    int c;

    Cell(int r_, int c_) : r(r_), c(c_) {}
};

struct Solution
{
    int n;
    int m;
    int nmask;
    vector<int> pow;
    vector<vector<int> > g;
    vector<vector<vector<int> > > a;
    Solution() {}

    void solve()
    {
        cin >> n >> m;
        g.resize(n, vector<int>(m, 0));

        for(int i = 0; i < n; i++)
        {
            for(int j = 0; j < m; j++)
            {
                cin >> g[i][j];
            }
        }

        if(n == 1) Sub1();
        else Sub4();
    }

    void Sub1()
    {
        int t = 0;
        int cnt = 1;
        int64_t ans = 0;
        for(int i = 1; i < m; i++)
        {
            int x = 1;
            if(g[0][i] < g[0][i - 1]) x = -1;
            if(t != x && t != 0)
            {
                ans += cnt * (cnt - 1) / 2;
                cnt = 1;
            }

            cnt++;
            t = x;
        }

        ans += cnt * (cnt - 1) / 2;
        cout << ans + m << "\n";
    }

    void Sub4()
    {
        int k = 3;
        nmask = 81;
        pow.resize(k + 1, 1);
        for(int i = 1; i <= k; i++)
        {
            pow[i] = pow[i - 1] * 3;
        }

        a.resize(nmask, vector<vector<int> >(n + 1, vector<int>(m + 1, 0)));
        for(int mask = 0; mask < nmask; mask++)
        {
            vector<int> c;
            int tmp = mask;
            for(int i = k; i >= 0; i--)
            {
                int d = tmp / pow[i];
                tmp %= pow[i];

                c.push_back(d);
            }
            reverse(c.begin(), c.end());
            for(int i = 0; i < n; i++)
            {
                for(int j = 0; j < m; j++)
                {
                    if(!check(c, i, j)) continue;

                    int cnt = 0;
                    for(int t = 0; t < D; t++)
                    {
                        if(c[t] == 0) continue;

                        Cell v(i + dr[t], j + dc[t]);
                        if(g[v.r][v.c] < g[i][j]) continue;

                        int max_ = g[i][j];
                        if(c[t] == 2)
                        {
                            int val = g[v.r + dr[t]][v.c + dc[t]];
                            if(val < g[v.r][v.c]) max_ = max(max_, val);
                        }

                        int nx = t + 1; if(nx == D) nx -= D;
                        int pr = t - 1; if(pr < 0) pr += D;
                        if(c[nx] != 0)
                        {
                            int val = g[v.r + dr[nx]][v.c + dc[nx]];
                            if(val < g[v.r][v.c]) max_ = max(max_, val);
                        }
                        if(c[pr] != 0)
                        {
                            int val = g[v.r + dr[pr]][v.c + dc[pr]];
                            if(val < g[v.r][v.c]) max_ = max(max_, val);
                        }

                        if(max_ == g[i][j]) cnt++;
                    }

                    if(cnt == 0) a[mask][i][j]++;
                }
            }
        }

        vector<vector<vector<int> > > prefix(nmask, vector<vector<int> >(n + 1, vector<int>(m + 1, 0)));
        for(int mask = 0; mask < nmask; mask++)
        {
            for(int i = 1; i <= n; i++)
            {
                for(int j = 1; j <= m; j++)
                {
                    prefix[mask][i][j] -= prefix[mask][i - 1][j - 1];
                    prefix[mask][i][j] += a[mask][i - 1][j - 1] + prefix[mask][i - 1][j] + prefix[mask][i][j - 1];
                }
            }
        }

        int ans = 0;
        for(int i = 0; i < n; i++)
        {
            for(int j = 0; j < m; j++)
            {
                for(int r = i; r < n; r++)
                {
                    for(int c = j; c < m; c++)
                    {
                        int sum = 0;
                        int lx = i; int rx = r;
//                        cerr << "(" << i << "," << j << ")\n";
                         cerr << "(" << r << "," << c << ")\n";
                        while(lx <= rx)
                        {
                            int t1 = getMask(i, j, r, c, lx, c);
                            int t2 = getMask(i, j, r, c, rx, c);
                            if(t1 == t2 && lx != rx)
                            {
                                int lly = j; int lry = c;
                                int rly = j; int rry = c;
                                while(lly <= lry)
                                {
                                    int t3 = getMask(i, j, r, c, lx, lly);
                                    int t4 = getMask(i, j, r, c, lx, lry);

                                    if(t3 == t4) sum += prefix[t3][rx + 1][rry + 1] - prefix[t3][rx + 1][rly] - prefix[t3][lx][lry] + prefix[t3][lx][lly];
                                    else
                                    {
                                        sum += prefix[t3][rx + 1][rly + 1] - prefix[t3][rx + 1][rly] - prefix[t3][lx][lly + 1] + prefix[t3][lx][lly];
                                        sum += prefix[t4][rx + 1][rry + 1] - prefix[t4][rx + 1][rry] - prefix[t4][lx][lry + 1] + prefix[t3][lx][lry];
                                        break;
                                    }

                                    lly++; lry--;
                                    rly++; rry--;
                                }
                                break;
                            }
                            else
                            {
                                int lly = j; int lry = c;
                                int rly = j; int rry = c;
                                while(lly <= lry)
                                {
                                    int t3 = getMask(i, j, r, c, lx, lly);
                                    int t4 = getMask(i, j, r, c, lx, lry);
                                    int t5 = getMask(i, j, r, c, rx, rly);
                                    int t6 = getMask(i, j, r, c, rx, rry);
//                                    cerr << "t3 = " << t3 << " ; t4 = " << t4 << "\n";
                                    if(t3 != t4)
                                    {
                                        sum += a[t3][lx][lly] + a[t4][lx][lry];
                                        if(lx != rx) sum += a[t5][rx][rly] + a[t6][rx][rry];
                                    }
                                    else
                                    {
                                        sum += prefix[t3][lx + 1][lry + 1] - prefix[t3][lx + 1][lly] - prefix[t3][lx][lry + 1] + prefix[t3][lx][lly];
                                        if(lx != rx) sum += prefix[t5][rx + 1][rry + 1] - prefix[t5][rx + 1][rly] - prefix[t5][rx][rry + 1] + prefix[t5][rx][rly];
                                        break;
                                    }

                                    lly++; lry--;
                                    rly++; rry--;
                                }
                            }
                            lx++; rx--;
                        }
//                        cerr << "\n";
                        if(sum == 1) ans++;
                    }
                }
            }
        }

        cout << ans << "\n";
    }

    int getMask(int i, int j, int r, int c, int x, int y)
    {
        int u = min(x - i, 2); int d = min(r - x, 2);
        int l = min(y - j, 2); int rt = min(c - y, 2);
        return u * pow[0] + rt * pow[1] + d * pow[2] + l * pow[3];
    }

    bool check(vector<int>& p, int r, int c)
    {
        return 0 <= r - p[0] && r + p[2] < n &&
               0 <= c - p[3] && c + p[1] < m;
    }

    int modadd(int t1, int t2)
    {
        int res = t1 + t2;
        if(res >= MOD) res -= MOD;
        return res;
    }

    int modmul(int t1, int t2)
    {
        int64_t res = 1LL * t1 * t2;
        return res % MOD;
    }

    int Abs(int tmp)
    {
        if(tmp < 0) return -tmp;
        return tmp;
    }

    int64_t MASK(int j)
    {
        return (1LL << j);
    }

    bool BIT(int64_t tmp, int j)
    {
        return (tmp & MASK(j));
    }
};

void __setup()
{
    ios_base::sync_with_stdio(false);
    cin.tie(NULL); cout.tie(NULL);

//    freopen("in.txt", "r", stdin);
//    freopen("out.txt", "w", stdout);
}

int main()
{
    __setup();

    int t = 1;
//    cin >> t;
    for(int i = 1; i <= t; i++)
    {
        Solution().solve();
    }
    return 0;
}
# 결과 실행 시간 메모리 Grader output
1 Correct 1 ms 212 KB Output is correct
2 Incorrect 4 ms 784 KB Output isn't correct
3 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Incorrect 15 ms 524 KB Output isn't correct
2 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Incorrect 15 ms 524 KB Output isn't correct
2 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Incorrect 15 ms 524 KB Output isn't correct
2 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Incorrect 15 ms 524 KB Output isn't correct
2 Halted 0 ms 0 KB -