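/*
 * Submission #403026 (제출 #403026)
 *
 * #       | Submitted at  | User ID  | Problem             | Language | Result    | Time   | Memory
 * 403026  | (not captured)| 12tqian  | Tents (JOI18_tents) | C++17    | 100 / 100 | 325 ms | 35904 KiB
 */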
#include <bits/stdc++.h>

using namespace std;

// Loop shorthands: f1r runs int i over [a, b), f0r runs i over [0, a),
// each binds t by reference to every element of a.
// NOTE(review): a and b are not parenthesised in the expansion; pass simple expressions.
#define f1r(i, a, b) for (int i = a; i < b; ++i)
#define f0r(i, a) f1r(i, 0, a)
#define each(t, a) for (auto& t : a)

// Short aliases for frequently used std names.
// NOTE(review): `f` and `s` are object-like macros, so every later token spelled
// `f` or `s` in this file is rewritten to `first` / `second`. Do not use them as
// variable or parameter names below this point.
#define mp make_pair
#define f first
#define s second
#define pb push_back
#define eb emplace_back
// Container size as int, and a begin/end pair for handing a whole range to an algorithm.
#define sz(x) (int)(x).size()
#define all(x) begin(x), end(x)

// Integer, vector and pair type aliases used throughout the file.
typedef long long ll;
typedef vector<int> vi;
typedef vector<ll> vl;
typedef pair<int, int> pi;
typedef pair<ll, ll> pl;
typedef vector<pi> vpi;
typedef vector<pl> vpl;

// Lowers `a` to `b` when `b` compares strictly less; reports whether `a` changed.
template <class T> bool ckmin(T& a, const T& b) {
    if (!(b < a)) return false;
    a = b;
    return true;
}
// Raises `a` to `b` when `b` compares strictly greater; reports whether `a` changed.
template <class T> bool ckmax(T& a, const T& b) {
    if (!(a < b)) return false;
    a = b;
    return true;
}

// Exclusive upper bound on n and m: size of the factorial tables and of the dp memo.
const int N = 3005;
// Prime modulus; every count is computed modulo P.
const int P = 1e9 + 7;

// Returns b^e mod P, in [0, P), by binary exponentiation.
// The base is reduced into [0, P) first. This keeps every product below P^2, so a
// large |b| cannot overflow 64 bits, and a negative b still gives a non-negative
// result. The loop runs only while e > 0, so e <= 0 returns 1 rather than
// spinning (e >>= 1 never reaches 0 for a negative e).
int mpow(long long b, long long e) {
    b %= P;
    if (b < 0) b += P;
    long long r = 1;
    while (e > 0) {
        if (e & 1) r = r * b % P;
        b = b * b % P;
        e >>= 1;
    }
    return static_cast<int>(r);
}
// Modular inverse of b via Fermat's little theorem (P is prime).
// b must not be a multiple of P; in that case the result is 0.
int minv(long long b) { return mpow(b, P - 2); }
// (x + y) mod P; assumes x and y are already in [0, P).
int add(int x, int y) { x += y; if (x >= P) x -= P; return x; }
// (x - y) mod P; assumes x and y are already in [0, P).
int sub(int x, int y) { x -= y; if (x < 0) x += P; return x; }
// (x * y) mod P, with the product formed in 64 bits.
int mult(int x, int y) { return 1LL * x * y % P; }
// In-place variants: update x and return its new value.
int madd(int& x, int y) { return x = add(x, y); }
int mmult(int& x, int y) { return x = mult(x, y); }
int msub(int& x, int y) { return x = sub(x, y); }
// fact[i] = i! mod P and ifact[i] = (i!)^-1 mod P for 0 <= i < N; main() fills both.
int fact[N], ifact[N];

// Memo for solve(n, m); main() sets every entry to -1, meaning "not computed yet".
int dp[N][N];

// Binomial coefficient C(x, y) mod P, read from the factorial tables.
// Returns 0 whenever y lies outside [0, x]. The y < 0 check also prevents ifact
// from being indexed at a negative position. Requires x < N.
int C(int x, int y) {
    if (y < 0 || x < y) return 0;
    return mult(fact[x], mult(ifact[y], ifact[x - y]));
}

int solve(int n, int m) {
    auto& res = dp[n][m];
    if (res != -1) return res;
    res = 1; // no more pairs
    if (n == 0 || m == 0) {
        return res;
    }
    { // first thing isn't paired
        madd(res, sub(solve(n - 1, m), 1));
    }
    if (m >= 2) { // first thing is double paired
        madd(res, mult(C(m, 2), solve(n - 1, m - 2)));
    }
    if (n >= 2) { // first thing is double paired with same side
        madd(res, mult(mult(n - 1, m), solve(n - 2, m - 1)));
    }
    return res;
}

int main() {
    cin.tie(0)->sync_with_stdio(0);
    fact[0] = 1;
    f1r(i, 1, N) fact[i] = mult(fact[i - 1], i);
    ifact[N - 1] = minv(fact[N - 1]);
    f0r(i, N) f0r(j, N) dp[i][j] = -1;
    for (int i = N - 2; i >= 0; i--) ifact[i] = mult(ifact[i + 1], i + 1);
    int n, m; cin >> n >> m;
    int ans = 0;
    f0r(i, min(m, n) + 1) { // number of singles
        int res = mult(fact[i], mult(C(n, i), C(m, i)));
        mmult(res, mpow(4, i));
        int a = n - i;
        int b = m - i;
        madd(ans, mult(res, solve(a, b))); 
    }
    msub(ans, 1);
    cout << ans << '\n';
    return 0;
}

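/**
 * Model: rows and columns are the vertices of a bipartite graph, and each tent is an edge.
 *  - every vertex has degree at most 2;
 *  - no component is a path with more than 3 vertices (2 edges), so every component is
 *    either a single tent or a pair of tents sharing a row or a column;
 *  - each single-tent component can be oriented 4 ways, so an arrangement with i singles
 *    is weighted by 4^i.
 */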
/*
 * # | Verdict | Execution time | Memory | Grader output
 * Fetching results...
 * # | Verdict | Execution time | Memory | Grader output
 * Fetching results...
 */