답안 #723033

# | 제출 시각 (submitted at) | 아이디 (user) | 문제 (problem) | 언어 (language) | 결과 (result) | 실행 시간 (time) | 메모리 (memory)
723033 | 2023-04-13T07:42:18Z | 0__0 | Bi-ing Lottery Treekets (CCO22_day2problem1) | C++17 | 0 / 25 | 76 ms | 32528 KB
#include "bits/stdc++.h"

using namespace std;

// Debug printer, base case: just ends the current line (endl also flushes).
void abc() {cout << endl;}
// Debug printer: writes every argument to cout, each one followed by a single
// space, then ends the line with endl (one flush per call).
template <typename T, typename ...U> void abc(T a, U ...b) {
    cout << a << ' ';
    ((cout << b << ' '), ...);
    cout << endl;
}
// Prints the elements of [l, r) on one line of cout, separated by single
// spaces, with a newline after the last element. An empty range prints nothing.
template <typename T> void printv(T l, T r) {
    for (; l != r;) {
        cout << *l;
        ++l;
        cout << (l == r ? '\n' : ' ');
    }
}
// Reads a pair from the stream as two consecutive formatted values:
// first `a.first`, then `a.second`. Returns the stream for chaining.
template <typename A, typename B> istream& operator >> (istream& o, pair<A, B> &a) {
    o >> a.first;
    o >> a.second;
    return o;
}
// Writes a pair as "(first, second)". Returns the stream for chaining.
template <typename A, typename B> ostream& operator << (ostream& o, pair<A, B> a) {
    o << '(';
    o << a.first;
    o << ", ";
    o << a.second;
    o << ')';
    return o;
}
// Writes a vector as "{e0 e1 ... en}": braces around, single spaces between
// elements. An empty vector is written as "{}" (previously the opening brace
// was only emitted inside the loop, so an empty vector printed just "}").
// Takes the vector by const reference and elements by reference to avoid copies.
template <typename T> ostream& operator << (ostream& o, const vector<T>& a) {
    o << '{';
    bool first = true;
    for (const auto& x : a) {
        if (!first) o << ' ';
        first = false;
        o << x;
    }
    return o << '}';
}

// Debug tracing. When compiled with -Dlocal, test(x, y) prints the stringized
// argument list "[x, y]" followed by the argument values via abc(); otherwise it
// expands to a no-op. (`args...` is the GNU named-variadic-macro extension.)
#ifdef local
#define test(args...) abc("[" + string(#args) + "]", args)
#else
#define test(args...) void(0)
#endif

using ll = long long;

// loc[v]: number of balls initially on node v (incremented once per input ball in main).
int loc[4005];
// lpar[v] / rpar[v]: the two children of node v as read in main; 0 means "no child"
// (dfs only recurses into non-zero entries).
int lpar[4005], rpar[4005];
ll dp[4005][4005]; // at node i, need j balls from above
int above[4005]; // sum of loc[] over strict ancestors of the node (exclusive of the node itself)
int below[4005]; // sum of loc[] over the node's subtree (inclusive of the node itself)
int sz[4005];    // number of nodes in the node's subtree

const ll MOD = 1e9 + 7;

// factorial[i] = i! mod MOD.
ll factorial[4005];
// choose() reads inverse[b] as the modular inverse of b!, so this table must
// hold inverse factorials modulo MOD.
ll inverse[4005];

// Binomial coefficient C(a, b) modulo MOD.
// Uses factorial[x] = x! mod MOD and reads inverse[x] as the modular inverse of
// x!, so `inverse` must hold inverse factorials.
// Returns 0 when b < 0 or b > a (which also covers a < 0). The previous version
// indexed inverse[] with a negative subscript in that case (undefined behavior),
// and callers may pass such arguments.
ll choose(int a, int b) {
    if (b < 0 || b > a) return 0;
    return factorial[a] * inverse[b] % MOD * inverse[a - b] % MOD;
}

// Tree DP over the subtree rooted at `node`.
//
// Before recursing into a child, sets above[child] = above[node] + loc[node], so
// above[v] is the number of balls on strict ancestors of v. After the children
// return, below[node] is the number of balls in the subtree (inclusive) and
// sz[node] is the subtree's node count. If a subtree holds more balls than
// nodes, prints -1 and terminates the program.
//
// Then fills dp[node][i] (mod MOD) for i up to above[node]. Per the table's
// declaration comment, i is the number of balls needed from above.
// Children are lpar[node] / rpar[node]. Index 0 stands for "no child": dfs is
// never called on it, so sz[0] = below[0] = 0, and main sets dp[0][0] = 1.
//
// `par` is unused. When `left` is false, the two children's roles are swapped
// before the transitions are applied.
//
// Fix: `constant` is now always reduced mod MOD. Before, it could reach about
// MOD^2, and multiplying it by a value < MOD overflowed signed 64-bit (UB).
//
// NOTE(review): main prints 0 when k > n, but this prints -1 for an overfull
// subtree — confirm which value the problem expects.
// NOTE(review): this submission scored 0/25; the transition logic below is
// unverified and kept as-is apart from the overflow fix.
void dfs(int node, int par = -1, bool left = true) {

    below[node] = loc[node];
    sz[node] = 1;

    if (lpar[node]) {
        above[lpar[node]] = above[node] + loc[node];
        dfs(lpar[node], node, true);
        below[node] += below[lpar[node]];
        sz[node] += sz[lpar[node]];
    }
    if (rpar[node]) {
        above[rpar[node]] = above[node] + loc[node];
        dfs(rpar[node], node, false);
        below[node] += below[rpar[node]];
        sz[node] += sz[rpar[node]];
    }

    // A subtree cannot hold more balls than it has nodes.
    if (sz[node] < below[node]) {
        cout << -1 << '\n';
        exit(0);
    }

    int lc = lpar[node];
    int rc = rpar[node];
    int leftsz = sz[lc] - below[lc];   // nodes minus balls in the left subtree
    int rightsz = sz[rc] - below[rc];  // nodes minus balls in the right subtree
    int lefttaken = below[lc];         // balls in the left subtree
    int righttaken = below[rc];        // balls in the right subtree
    int space = sz[node] - lefttaken - righttaken;

    if (!left) {
        swap(lc, rc);
        swap(leftsz, rightsz);
        swap(lefttaken, righttaken);
    }

    // the root is somewhere else
    {
        for (int take = 0; take <= leftsz && take <= loc[node]; take++) {
            int takeo = loc[node] - take;

            for (int i = 0; i <= above[node] && i + loc[node] <= space; i++) {
                int amtleft = min(leftsz, take + i);
                int remain = take + i - amtleft;
                int amtright = min(rightsz, remain);

                if (amtleft + amtright < loc[node]) continue;

                ll constant = 1;

                if (loc[node] + i == amtleft + amtright + 1) {
                    constant = choose(loc[node], take) % MOD;
                    // NOTE(review): amtright - takeo can be negative here; C(n, k<0) should be 0.
                    constant = (constant * choose(i, amtleft - take) % MOD * choose(i - (amtleft - take), (amtright - takeo))) % MOD;
                } else {
                    constant = choose(loc[node], take) % MOD;
                    constant = (constant % MOD * choose(i, amtleft - take)) % MOD;
                }

                // Every factor is < MOD here, so each product fits in 64 bits.
                dp[node][i] = (dp[node][i] + ((dp[lc][amtleft] * dp[rc][amtright]) % MOD) * constant) % MOD;
            }
        }
    }

    // take root with itself
    {
        for (int take = 0; take <= leftsz && take < loc[node]; take++) {
            int takeo = loc[node] - take - 1;

            for (int i = 0; i <= above[node] && i + loc[node] <= space; i++) {

                int amtleft = min(leftsz, take + i);
                int remain = take + i - amtleft;
                int amtright = min(rightsz, remain);

                if (amtleft + amtright + 1 < loc[node]) continue;
                if (i + loc[node] != space) {
                    continue;
                }
                
                ll constant = 1;

                if (loc[node] + i == amtleft + amtright + 1) {
                    constant = (choose(loc[node], take) * choose(loc[node] - take, takeo)) % MOD;
                    constant = (constant * choose(i, amtleft - take)) % MOD;
                } else {
                    constant = choose(loc[node], take) % MOD;
                    constant = (constant * choose(i, amtleft - take)) % MOD;
                }

                // Every factor is < MOD here, so each product fits in 64 bits.
                dp[node][i] = (dp[node][i] + ((dp[lc][amtleft] * dp[rc][amtright]) % MOD) * constant) % MOD;
            }
        }
    }
}

// Reads n, k, the k ball positions and the n child pairs, runs the tree DP from
// node 1 and prints dp[1][0].
//
// Fix: choose() reads inverse[b] as (b!)^{-1} mod MOD. The old code stored only
// the plain modular inverse b^{-1}, which made binomials wrong (e.g. C(4,1)
// evaluated to 8). A prefix-product pass now converts the table to inverse
// factorials.
int main() {
    ios_base::sync_with_stdio(0);
    cin.tie(0); cout.tie(0);
    int n, k; cin >> n >> k;

    const int LIMIT = 4004;  // highest index filled in factorial[] / inverse[]

    // factorial[i] = i! mod MOD.
    factorial[0] = 1;
    for (int i = 1; i <= LIMIT; i++) {
        factorial[i] = (factorial[i - 1] * i) % MOD;
    }

    // Step 1: inverse[i] = i^{-1} mod MOD via the linear recurrence
    // i^{-1} = -(MOD / i) * (MOD % i)^{-1} (valid because MOD is prime).
    inverse[0] = 1;
    inverse[1] = 1;
    for (int i = 2; i <= LIMIT; i++) {
        inverse[i] = MOD - (MOD / i) * inverse[MOD % i] % MOD;
    }
    // Step 2: prefix-multiply so inverse[i] = (i!)^{-1} mod MOD, as choose() needs.
    // This must run after step 1 finishes, because step 1 reads the plain
    // inverses of smaller indices.
    for (int i = 2; i <= LIMIT; i++) {
        inverse[i] = inverse[i - 1] * inverse[i] % MOD;
    }

    // More balls than nodes: print 0 and stop.
    if (k > n) {
        cout << "0\n";
        exit(0);
    }
    // loc[t] counts the balls that start on node t.
    for (int i = 1; i <= k; i++) {
        int t; cin >> t;
        loc[t]++;
    }
    // Children of node i; 0 means "no child".
    for (int i = 1; i <= n; i++) cin >> lpar[i] >> rpar[i];

    // Base case for the empty child (index 0) used by dfs.
    dp[0][0] = 1;
    dfs(1);

    cout << dp[1][0] % MOD << "\n";
}
# 결과 실행 시간 메모리 Grader output
1 Incorrect 0 ms 340 KB Output isn't correct
2 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Correct 14 ms 17488 KB Output is correct
2 Incorrect 76 ms 32528 KB Output isn't correct
3 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Correct 11 ms 16624 KB Output is correct
2 Correct 12 ms 17364 KB Output is correct
3 Correct 11 ms 16708 KB Output is correct
4 Incorrect 13 ms 16764 KB Output isn't correct
5 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Incorrect 0 ms 340 KB Output isn't correct
2 Halted 0 ms 0 KB -