Submission #723111

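| # | Time | Username | Problem | Language | Result | Execution time | Memory |
|---|------|----------|---------|----------|--------|----------------|--------|
| 723111 | | 0__0 | Bi-ing Lottery Treekets (CCO22_day2problem1) | C++17 | 25 / 25 | 192 ms | 80460 KiB |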
// Tree DP for "Bi-ing Lottery Treekets" (CCO22_day2problem1).
//
// Input:  n k, then k node indices (where the balls are placed), then for each
//         node 1..n its left and right child (0 = no child). Node 1 is the root.
// Output: dp[1][0] modulo 1e9+7, or 0 when k > n or when some subtree holds
//         more balls than nodes.
#include <algorithm>
#include <cassert>
#include <cstdlib>
#include <iostream>
#include <vector>

namespace {

// Every counter is unsigned 64-bit, exactly as the former
// `#define int unsigned long long` made them. A type alias avoids redefining a
// keyword, which is not permitted in a translation unit that includes
// standard library headers.
using u64 = unsigned long long;

constexpr u64 MOD = 1000000007ULL;

// Fixed capacity for node indices and for the factorial tables.
// NOTE(review): assumes n < MAXN, exactly as the original fixed-size arrays did.
constexpr u64 MAXN = 10000;

u64 loc[MAXN];    // how many of the k input positions name this node
u64 lpar[MAXN];   // left child (0 = none)
u64 rpar[MAXN];   // right child (0 = none)
u64 above[MAXN];  // balls on strict ancestors of the node (exclusive)
u64 below[MAXN];  // balls in the node's subtree (inclusive)
u64 sz[MAXN];     // number of nodes in the node's subtree
u64 par[MAXN];    // parent index; only used to check that node 1 is the root

u64 factorial[MAXN];
u64 inverse[MAXN];   // modular inverses of 1..MAXN-1
u64 inversef[MAXN];  // modular inverses of the factorials

// dp[node][j]: value at `node` when j balls are needed from above.
// Each row holds only the indices that are ever written or read,
// min(above[node], sz[node] - below[node]) + 1 entries, and is released as soon
// as the parent has consumed it. This replaces a 10000 x 10000 static table of
// 8-byte values (~800 MB of address space). dp[0] = {1} stands for an absent
// child.
std::vector<std::vector<u64>> dp;

// Fills factorial, inverse and inverse-factorial tables modulo MOD for the
// whole array range (the original stopped at 9000 although the arrays hold
// 10000 entries).
void build_tables() {
    factorial[0] = 1;
    for (u64 i = 1; i < MAXN; ++i) {
        factorial[i] = factorial[i - 1] * i % MOD;
    }
    inverse[0] = 1;
    inverse[1] = 1;
    for (u64 i = 2; i < MAXN; ++i) {
        // Standard linear-time inverse recurrence for a prime modulus.
        inverse[i] = MOD - (MOD / i) * inverse[MOD % i] % MOD;
    }
    inversef[0] = 1;
    for (u64 i = 1; i < MAXN; ++i) {
        inversef[i] = inversef[i - 1] * inverse[i] % MOD;
    }
}

// Binomial coefficient C(a, b) modulo MOD; 0 when a < b.
u64 choose(u64 a, u64 b) {
    if (a < b) return 0;
    return factorial[a] % MOD * inversef[b] % MOD * inversef[a - b] % MOD;
}

// Adds constant * dp[lc][a] * dp[rc][b] (mod MOD) into dp[node][i].
void add_term(u64 node, u64 i, u64 constant, u64 lc, u64 a, u64 rc, u64 b) {
    u64 term = dp[lc][a] * dp[rc][b] % MOD;
    term = constant % MOD * term % MOD;
    dp[node][i] = (dp[node][i] % MOD + term) % MOD;
}

// Post-order DP over the subtree rooted at `node`.
// `left` is true for the root and for left children. For a right child the
// roles of its two children are swapped before combining.
// Prints 0 and terminates the program when the subtree holds more balls than
// nodes.
void dfs(u64 node, bool left = true) {
    below[node] = loc[node];
    sz[node] = 1;
    if (lpar[node]) {
        const u64 c = lpar[node];
        above[c] = above[node] + loc[node];
        dfs(c, true);
        below[node] += below[c];
        sz[node] += sz[c];
    }
    if (rpar[node]) {
        const u64 c = rpar[node];
        above[c] = above[node] + loc[node];
        dfs(c, false);
        below[node] += below[c];
        sz[node] += sz[c];
    }
    if (sz[node] < below[node]) {
        std::cout << 0;
        std::exit(0);
    }

    u64 lc = lpar[node];
    u64 rc = rpar[node];
    u64 leftsz = sz[lc] - below[lc];
    u64 rightsz = sz[rc] - below[rc];
    // Nodes of this subtree not already filled by the children's own balls.
    const u64 space = sz[node] - below[lc] - below[rc];
    if (!left) {
        std::swap(lc, rc);
        std::swap(leftsz, rightsz);
    }

    const u64 balls = loc[node];
    const u64 up = above[node];
    // space - balls == sz[node] - below[node] >= 0 here, so this cannot wrap.
    // Both loops below only touch i <= min(up, space - balls).
    dp[node].assign(std::min(up, space - balls) + 1, 0);

    // Author's case "the root is somewhere else": all `balls` of this node are
    // split between the children (take + takeo == balls).
    for (u64 take = 0; take <= leftsz && take <= balls; ++take) {
        const u64 takeo = balls - take;
        if (takeo > rightsz) continue;
        for (u64 i = 0; i <= up && i + balls <= space; ++i) {
            const u64 amtleft = std::min(leftsz, take + i);
            const u64 remain = take + i - amtleft;
            const u64 amtright = std::min(rightsz, remain + takeo);
            if (amtleft + amtright < balls) continue;
            const u64 fromAbove = amtleft - take;  // <= i because take <= leftsz
            u64 constant = choose(balls, take) % MOD;
            if (balls + i == amtleft + amtright + 1) {
                constant = constant * choose(i, fromAbove) % MOD * choose(i - fromAbove, 1);
            } else {
                constant = constant % MOD * choose(i, fromAbove);
            }
            add_term(node, i, constant, lc, amtleft, rc, amtright);
        }
    }

    // Author's case "take root with itself": one of this node's balls is not
    // sent to a child (take + takeo == balls - 1); only i == space - balls counts.
    for (u64 take = 0; take <= leftsz && take < balls; ++take) {
        const u64 takeo = balls - take - 1;
        if (takeo > rightsz) continue;
        for (u64 i = 0; i <= up && i + balls <= space; ++i) {
            const u64 amtleft = std::min(leftsz, take + i);
            const u64 remain = take + i - amtleft;
            const u64 amtright = std::min(rightsz, remain + takeo);
            if (amtleft + amtright + 1 < balls) continue;
            if (i + balls != space) continue;
            const u64 fromAbove = amtleft - take;
            u64 constant;
            if (balls + i == amtleft + amtright + 1) {
                constant = choose(balls, take) * choose(balls - take, takeo) % MOD;
                constant = constant * choose(i, fromAbove) % MOD;
            } else {
                constant = choose(balls, take) % MOD;
                constant = constant * choose(i, fromAbove);
            }
            add_term(node, i, constant, lc, amtleft, rc, amtright);
        }
    }

    // Only this node reads its children's rows (the input is a tree), so the
    // memory can be returned now. dp[0] (absent child) is never released.
    if (lpar[node]) std::vector<u64>().swap(dp[lpar[node]]);
    if (rpar[node]) std::vector<u64>().swap(dp[rpar[node]]);
}

}  // namespace

int main() {
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);

    u64 n = 0;
    u64 k = 0;
    std::cin >> n >> k;
    build_tables();
    if (k > n) {
        std::cout << "0\n";
        return 0;
    }
    for (u64 i = 1; i <= k; ++i) {
        u64 t = 0;
        std::cin >> t;
        loc[t]++;
    }
    for (u64 i = 1; i <= n; ++i) {
        std::cin >> lpar[i] >> rpar[i];
        // par[0] is overwritten for missing children; harmless, never read.
        par[lpar[i]] = par[rpar[i]] = i;
    }
    assert(par[1] == 0);

    dp.assign(n + 1, std::vector<u64>());
    dp[0] = std::vector<u64>{1};
    dfs(1);
    std::cout << dp[1][0] % MOD << "\n";
    return 0;
}
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...