// Submission #999536
//
// # | Time | Username | Problem | Language | Result | Execution time | Memory
// 999536 | aykhn | Sumtree (INOI20_sumtree) | C++17
// 100 / 100
// 1220 ms | 154136 KiB
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <utility>
#include <vector>

// Solution for "Sumtree" (INOI20_sumtree).
//
// The tree is rooted at vertex 1 and numbered by a DFS so that the subtree of
// u occupies positions [in[u], out[u]]. Some vertices are "tagged" with a value
// tag[u] (the root starts tagged with r; untagged vertices hold -1). For a
// tagged vertex u let D(u) be its nearest tagged proper descendants and
//   cnt(u) = sz[u]  - sum of sz[w]  over w in D(u)
//   val(u) = tag[u] - sum of tag[w] over w in D(u).
// Its factor is C(cnt(u) - 1 + val(u), val(u)). After the initial state and
// after every query the program prints the product of all factors mod 1e9+7,
// or 0 while some tagged vertex has val(u) < 0.
//
// Segment-tree invariants (for every tagged vertex w other than the root;
// untagged positions hold 0 in stsz and sts):
//   * stup: point value at in[x] = number of tagged vertices on x's root path.
//   * stsz: w's entry is cnt(w), so the sum over w's subtree equals sz[w].
//   * sts:  w's entry is val(w), so the sum over w's subtree equals tag[w].
// Hence summing stsz / sts strictly inside u's subtree gives the D(u) sums.

using ll = long long;

constexpr int MXN = 1'000'005;  // capacity for vertices and the factorial table
constexpr int LOG = 20;         // binary-lifting levels; 2^LOG > MXN
constexpr ll mod = 1'000'000'007;

// Segment tree over positions 1..sz with range add and point query: the value
// at a position is the sum of the adds stored on its root-to-leaf path.
struct FindUp {
  int sz;
  std::vector<ll> st;

  void init(int n) {
    sz = n;
    st.assign((n + 1) << 2, 0);
  }

  // Adds val to every position in [lx, rx]; node x covers [l, r].
  void add(int l, int r, int x, int lx, int rx, ll val) {
    if (l > rx || r < lx) return;
    if (lx <= l && r <= rx) {
      st[x] += val;
      return;
    }
    const int mid = (l + r) >> 1;
    add(l, mid, 2 * x, lx, rx, val);
    add(mid + 1, r, 2 * x + 1, lx, rx, val);
  }

  // Returns the accumulated value at position ind.
  ll get(int l, int r, int x, int ind) const {
    if (l == r) return st[x];
    const int mid = (l + r) >> 1;
    return st[x] + (ind <= mid ? get(l, mid, 2 * x, ind)
                               : get(mid + 1, r, 2 * x + 1, ind));
  }
};

// Segment tree over positions 1..sz with point add and range sum.
struct FindVAL {
  int sz;
  std::vector<ll> st;

  void init(int n) {
    sz = n;
    st.assign((n + 1) << 2, 0);
  }

  // Adds val at position ind; node x covers [l, r].
  void add(int l, int r, int x, int ind, ll val) {
    if (l == r) {
      st[x] += val;
      return;
    }
    const int mid = (l + r) >> 1;
    if (ind <= mid)
      add(l, mid, 2 * x, ind, val);
    else
      add(mid + 1, r, 2 * x + 1, ind, val);
    st[x] = st[2 * x] + st[2 * x + 1];
  }

  // Sum over [lx, rx]; an empty range (lx > rx) yields 0.
  ll get(int l, int r, int x, int lx, int rx) const {
    if (l > rx || r < lx) return 0;
    if (lx <= l && r <= rx) return st[x];
    const int mid = (l + r) >> 1;
    return get(l, mid, 2 * x, lx, rx) + get(mid + 1, r, 2 * x + 1, lx, rx);
  }
};

int n;                       // number of vertices
ll r;                        // initial tag of the root
std::vector<int> adj[MXN];   // undirected adjacency lists
ll f[MXN];                   // f[i] = i! mod `mod`
ll tag[MXN];                 // tag value, or -1 when untagged
int p[LOG][MXN];             // p[i][v] = 2^i-th ancestor of v (root maps to itself)
int sz[MXN];                 // subtree sizes
int in[MXN], out[MXN], tim;  // DFS interval of every subtree
FindUp stup;
FindVAL stsz, sts;
ll res = 1;  // product of the factors of tagged vertices with val >= 0
ll z = 0;    // number of tagged vertices with val < 0 (their factor would be 0)

// Returns a^b mod c by binary exponentiation (expects c > 0 and b >= 0).
ll pw(ll a, ll b, ll c) {
  a %= c;
  ll result = 1;
  for (; b > 0; b >>= 1) {
    if (b & 1) result = result * a % c;
    a = a * a % c;
  }
  return result;
}

// Modular inverse of a nonzero residue x; `mod` is prime, so Fermat applies.
ll inv(ll x) { return pw(x, mod - 2, mod); }

// Binomial coefficient C(m, k) mod `mod`; 0 if m < 0, k < 0 or k > m.
// Requires m < MXN, since f[] only holds factorials up to MXN - 1.
ll nck(ll m, ll k) {
  if (std::min(m, k) < 0 || k > m) return 0;
  return f[m] * inv(f[m - k] * f[k] % mod) % mod;
}

// Fills p[0][], sz[], in[] and out[] for the subtree of `root`, visiting
// children in adjacency order and numbering vertices in preorder via `tim`.
// Uses an explicit stack instead of recursion so that a path-shaped tree with
// up to ~1e6 vertices cannot overflow the call stack.
void dfs(int root) {
  std::vector<std::pair<int, std::size_t>> stk;  // (vertex, next adj index)
  sz[root] = 1;
  in[root] = ++tim;
  stk.emplace_back(root, 0);
  while (!stk.empty()) {
    const int a = stk.back().first;
    std::size_t &next = stk.back().second;
    if (next < adj[a].size()) {
      const int v = adj[a][next++];
      if (v == p[0][a]) continue;
      p[0][v] = a;
      sz[v] = 1;
      in[v] = ++tim;
      stk.emplace_back(v, 0);  // may invalidate `next`; it is not used again
    } else {
      out[a] = tim;
      stk.pop_back();
      if (!stk.empty()) sz[stk.back().first] += sz[a];
    }
  }
}

// Factor C(cnt(u) - 1 + val(u), val(u)) of tagged vertex u (0 if val(u) < 0).
ll getco(int u) {
  const ll innerSize = stsz.get(1, n, 1, in[u] + 1, out[u]);
  const ll innerTag = sts.get(1, n, 1, in[u] + 1, out[u]);
  const ll cnt = sz[u] - innerSize;
  const ll val = tag[u] - innerTag;
  return nck(cnt - 1 + val, val);
}

// True when the tags of u's nearest tagged descendants exceed tag[u].
bool invalid(int u) {
  return sts.get(1, n, 1, in[u] + 1, out[u]) > tag[u];
}

// Nearest tagged vertex on u's root path (u itself if tagged): climbs by binary
// lifting while the tagged-ancestor count stays equal to u's.
int UP(int u) {
  const ll x = stup.get(1, n, 1, in[u]);
  for (int i = LOG - 1; i >= 0; --i)
    if (stup.get(1, n, 1, in[p[i][u]]) == x) u = p[i][u];
  return u;
}

int main() {
  std::ios::sync_with_stdio(false);
  std::cin.tie(nullptr);

  f[0] = 1;
  for (int i = 1; i < MXN; ++i) {
    f[i] = f[i - 1] * i % mod;
    tag[i] = -1;
  }

  std::cin >> n >> r;
  for (int i = 1; i < n; ++i) {
    int u, v;
    std::cin >> u >> v;
    adj[u].push_back(v);
    adj[v].push_back(u);
  }
  ll q;
  std::cin >> q;

  tag[1] = r;
  p[0][1] = 1;  // the root is its own parent, so lifting saturates there
  dfs(1);
  for (int i = 1; i < LOG; ++i)
    for (int j = 1; j <= n; ++j) p[i][j] = p[i - 1][p[i - 1][j]];

  stup.init(n);
  stsz.init(n);
  sts.init(n);
  stup.add(1, n, 1, 1, n, 1);   // only the root is tagged initially
  stsz.add(1, n, 1, in[1], 1);  // never read: every query starts at in[u] + 1
  res = getco(1);
  std::cout << res << '\n';

  while (q--) {
    int t;
    std::cin >> t;
    if (t == 1) {
      // Tag u with v. P is the nearest tagged vertex above u (u is assumed to
      // be untagged, so UP(u) skips it).
      int u;
      ll v;
      std::cin >> u >> v;
      const int P = UP(u);
      ll S = stsz.get(1, n, 1, in[u] + 1, out[u]);
      tag[u] = v;
      // Take P's old factor out before u splits off part of P's region.
      if (!invalid(P))
        res = res * inv(getco(P)) % mod;
      else
        --z;
      // Move u's region size from P's entry to u's entry.
      stsz.add(1, n, 1, in[u], sz[u] - S);
      stsz.add(1, n, 1, in[P], S - sz[u]);
      // Move u's remaining tag budget likewise.
      S = sts.get(1, n, 1, in[u] + 1, out[u]);
      sts.add(1, n, 1, in[u], tag[u] - S);
      sts.add(1, n, 1, in[P], S - tag[u]);
      // Put back P's updated factor and add u's factor.
      if (!invalid(P))
        res = res * getco(P) % mod;
      else
        ++z;
      if (!invalid(u))
        res = res * getco(u) % mod;
      else
        ++z;
      stup.add(1, n, 1, in[u], out[u], 1);
    } else {
      // Untag u; its region and budget merge back into P's.
      int u;
      std::cin >> u;
      stup.add(1, n, 1, in[u], out[u], -1);
      const int P = UP(u);
      ll S = stsz.get(1, n, 1, in[u] + 1, out[u]);
      if (!invalid(P))
        res = res * inv(getco(P)) % mod;
      else
        --z;
      if (!invalid(u))
        res = res * inv(getco(u)) % mod;
      else
        --z;
      stsz.add(1, n, 1, in[u], S - sz[u]);
      stsz.add(1, n, 1, in[P], sz[u] - S);
      S = sts.get(1, n, 1, in[u] + 1, out[u]);
      sts.add(1, n, 1, in[u], S - tag[u]);
      sts.add(1, n, 1, in[P], tag[u] - S);
      if (!invalid(P))
        res = res * getco(P) % mod;
      else
        ++z;
      tag[u] = -1;
    }
    std::cout << (z ? 0 : res) << '\n';
  }
  return 0;
}
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...