Submission #1325923

# | Time | Username | Problem | Language | Result | Execution time | Memory
1325923 | | perchuts | Kangaroo (CEOI16_kangaroo) | C++20 | 100 / 100 | 462 ms | 477796 KiB
#include <bits/stdc++.h>
#define all(x) x.begin(), x.end()
#define sz(x) (int) x.size()
#define pb push_back
#define _ ios_base::sync_with_stdio(false);cin.tie(NULL);cout.tie(NULL);
#define int long long
//#define gato

using namespace std;

// Short aliases for frequently used integer and tuple types.
using ll = long long;
using ull = unsigned long long;
using ii = pair<int, int>;
using iii = tuple<int, int, int>;

// inf: large sentinel value (not referenced below).
// mod: modulus for every count computed in this file (10^9 + 7).
// maxn: size bound (not referenced below).
const int inf = 2'000'000'001;
const int mod = 1'000'000'007;
const int maxn = 300'100;

// ckmin: if y < x, overwrite x with y and report true; otherwise leave x unchanged and report false.
template<typename X, typename Y> bool ckmin(X& x, const Y& y) {
    if (!(y < x)) return false;
    x = y;
    return true;
}
// ckmax: if x < y, overwrite x with y and report true; otherwise leave x unchanged and report false.
template<typename X, typename Y> bool ckmax(X& x, const Y& y) {
    if (!(x < y)) return false;
    x = y;
    return true;
}

// Global Mersenne Twister, seeded from the current steady-clock reading.
std::mt19937 rng(std::chrono::steady_clock::now().time_since_epoch().count());

// Draws an integer uniformly from the closed range [l, r] using `rng` (requires l <= r).
int rnd(int l, int r) {
    return std::uniform_int_distribution<int>(l, r)(rng);
}

// Binomial coefficients modulo `mod`. main() fills ncr[i][j] = C(i, j) for
// 0 <= j <= i <= 2000 using Pascal's rule before anything reads the table.
// Entries with j > i are never written and stay 0 (static storage is
// zero-initialized). n3() reads this table.
int ncr[2005][2005];

// Slower reference counter, used only by the `gato` stress test in main(),
// which expects it to agree with solve(n, a, b). Result is taken modulo `mod`.
// Requires the global `ncr` table to be filled (main() does this first).
// Three nested O(n) loops after an O(n^2) table build -> O(n^3) overall.
int n3(int n, int a, int b) {
    // With n == 2 (and a != b, as the stress test guarantees) exactly one
    // permutation exists.
    if (n == 2) return 1;
    if (a > b) swap(a, b);
    // Same f / s tables as in solve():
    //   f[1][1][1] = 1,
    //   f[i][j][0] = s[i-1][j-1][1],
    //   f[i][j][1] = s[i-1][i-1][0] - s[i-1][j-1][0],
    // where s[i][j][t] = f[i][1][t] + ... + f[i][j][t] (row prefix sums).
    vector f(n+1, vector(n+1, vector(2, 0LL)));
    vector s(n+1, vector(n+1, vector(2, 0LL)));

    for (int i = 1; i <= n; ++i) {
        for (int j = 1; j <= i; ++j) {
            if (i == 1) {
                f[i][i][1] = s[i][i][1] = 1;
                continue;
            }
            f[i][j][0] = s[i-1][j-1][1];
            f[i][j][1] = (s[i-1][i-1][0] - s[i-1][j-1][0] + mod) % mod;
            s[i][j][0] = (s[i][j-1][0] + f[i][j][0]) % mod;
            s[i][j][1] = (s[i][j-1][1] + f[i][j][1]) % mod;
        }
    }

    // When the end point is the maximum value n, the answer is read directly
    // from row n-1 of f (both first-move directions summed).
    if (b == n) return (f[n-1][a][0] + f[n-1][a][1]) % mod;

    // Otherwise sum over part sizes s1 + s2 = n - 1 (1 <= s1 <= n-2) and
    // indices i1, i2 (with j = b - i1 - i2 derived from them) a product of
    // three binomial coefficients and a pair of f entries. Which f components
    // are paired depends on the parity of n. The loop bounds keep 1 <= j <= s2.
    // NOTE(review): the combinatorial meaning of this decomposition is not
    // documented in the source; confirm before modifying.
    int ans = 0;
    for (int s1 = 1; s1 <= n-2; ++s1) {
        int s2 = n-1-s1;
        for (int i1 = 1; i1 <= min(a, s1); ++i1) {
            for (int i2 = max(0LL, b-i1-s2); i1 + i2 <= s1 and i2 < b-a; ++i2) {
                int j = b-i1-i2;
                int ways = ncr[a-1][i1-1] * ncr[b-a-1][i2] % mod * ncr[n-1-b][s2-j] % mod;
                if (n & 1) ways = ways * ((f[s1][i1][0] * f[s2][j][0]) % mod + (f[s1][i1][1] * f[s2][j][1]) % mod) % mod;
                else ways = ways * ((f[s1][i1][1] * f[s2][j][0]) % mod + (f[s1][i1][0] * f[s2][j][1]) % mod) % mod;
                ans = (ans + ways) % mod;
            }
        }
    }
    return ans;
}
// Counts, modulo `mod`, the permutations of 1..N that start at a, end at b
// and change direction at every step (CEOI 2016 "Kangaroo").
//
// Derivation: let A[n][i][j] / D[n][i][j] count such permutations of length n
// from i to j whose first move goes up / down, X = A + D and Y = A - D. Then
//   X[n][i][j] = X[n][i-1][j] + Y[n-1][i-1][j-1]
//   Y[n][i][j] = Y[n][i-1][j] - X[n-1][i-1][j-1]
// and eliminating Y gives, for n >= 3 and i >= 3,
//   X[n][i][j] = 2*X[n][i-1][j] - X[n][i-2][j] - X[n-2][i-2][j-2].
// Every term keeps n - j constant, so with d = N - b only j = n - d is ever
// needed and the state collapses to X[n][i].
//
// Memory: only O(N) values are kept (two prefix-sum rows, one column of f and
// three rows of X). Materializing full (N+1)^2 tables, especially with a
// heap-allocated 2-element vector per cell, costs hundreds of MB at N = 2000.
int solve(int N, int a, int b) {
    // Reversing a valid permutation yields a valid one with swapped endpoints,
    // so we may assume a < b.
    if (a > b) std::swap(a, b);
    const int d = N - b;
    const int col = d + 1;  // the only column of f read by the X recurrence

    // f[i][j][t] for 1 <= j <= i is defined by
    //   f[1][1] = {0, 1},
    //   f[i][j][0] = sum of f[i-1][k][1] over 1 <= k < j,
    //   f[i][j][1] = sum of f[i-1][k][0] over j <= k <= i-1
    // (presumably t = 0 / 1 means the first move goes down / up from the j-th
    // smallest of i values). Row i only needs prefix sums of row i-1, so two
    // rolling prefix-sum rows suffice; fcol[i] records f[i][col].
    using Pair = std::array<long long, 2>;
    std::vector<Pair> fcol(N + 1, Pair{0, 0});
    std::vector<Pair> prev(N + 1, Pair{0, 0});  // prefix sums of row i-1
    std::vector<Pair> cur(N + 1, Pair{0, 0});   // prefix sums of row i; index 0 stays {0, 0}
    for (int i = 1; i < N; ++i) {  // only rows up to N-1 are read below
        for (int j = 1; j <= i; ++j) {
            Pair fv{0, 1};  // base case f[1][1]
            if (i > 1) {
                fv[0] = prev[j - 1][1];
                fv[1] = (prev[i - 1][0] - prev[j - 1][0] + mod) % mod;
            }
            cur[j][0] = (cur[j - 1][0] + fv[0]) % mod;
            cur[j][1] = (cur[j - 1][1] + fv[1]) % mod;
            if (j == col) fcol[i] = fv;
        }
        std::swap(prev, cur);
    }

    // rows[n % 3][i] holds X[n][i] for 1 <= i < n - d. Only rows n, n-1 and
    // n-2 are ever read, and each value read was written in the current cycle.
    std::vector<std::vector<long long>> rows(3, std::vector<long long>(N + 1, 0));
    for (int n = 2; n <= N; ++n) {
        std::vector<long long>& x = rows[n % 3];
        const std::vector<long long>& x1 = rows[(n + 2) % 3];  // X[n-1]
        const std::vector<long long>& x2 = rows[(n + 1) % 3];  // X[n-2]
        const int j = n - d;
        for (int i = 1; i < j; ++i) {
            if (i == 1) {
                // From value 1 the first move must go up, so X = A; read it off f.
                x[1] = fcol[n - 1][n % 2 ? 0 : 1];
            } else if (i == 2) {
                // A[n][2][j] = A[n][1][j] (since D[n][1][j] = 0) and
                // D[n][2][j] = A[n-1][1][j-1].
                x[2] = (x1[1] + x[1]) % mod;
            } else {
                // All operands lie in [0, mod), so adding 2*mod keeps it non-negative.
                x[i] = (2 * x[i - 1] - x[i - 2] - x2[i - 2] + 2 * mod) % mod;
            }
        }
    }
    return rows[N % 3][a];
}

// Exhaustive reference: enumerates all n! orderings of 1..n and counts those
// that start at a, end at b and change direction at every step.
// Only practical for very small n.
int brute(int n, int a, int b) {
    std::vector<int> perm(n);
    std::iota(perm.begin(), perm.end(), 1);
    int count = 0;
    do {
        if (perm.front() != a || perm.back() != b) continue;
        bool zigzag = true;
        for (int k = 2; zigzag && k < n; ++k)
            zigzag = (perm[k] < perm[k - 1]) != (perm[k - 1] < perm[k - 2]);
        if (zigzag) ++count;
    } while (std::next_permutation(perm.begin(), perm.end()));
    return count;
}

int32_t main() {_
    for (int i = 0; i <= 2000; ++i) for (int j = 0; j <= i; ++j) {
        if (j == 0 or j == i) ncr[i][j] = 1;
        else ncr[i][j] = (ncr[i-1][j] + ncr[i-1][j-1]) % mod;
    }
#ifndef gato
    int n, a, b; cin >> n >> a >> b;
    cout << solve(n, a, b) << endl;
#else
    int t = 1;
    while (true) {
        int n = rnd(2, 100), a = rnd(1, n-1), b = rnd(a+1, n);
        int my = solve(n, a, b), ans = n3(n, a, b);
        if (my != ans) {
            cout << "Wrong answer on test " << t << endl;
            cout << n << " " << a << " " << b << endl;
            cout << "Solve: " << my << endl;
            cout << "Brute: " << ans << endl;
            exit(0);
        }
        cout << "Accepted on test " << t++ << endl;
    }
#endif
}
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...