Submission #780653

| # | Time | Username | Problem | Language | Result | Execution time | Memory |
|---|---|---|---|---|---|---|---|
| 780653 | | Sami_Massah | Sumtree (INOI20_sumtree) | C++17 | 100 / 100 | 2261 ms | 303656 KiB |
#include <bits/stdc++.h>

using namespace std;

// Solution to INOI 2020 "Sumtree" (oj.uz INOI20_sumtree).
//
// State (positions are Euler-tour indices st[u] assigned by dfs_set):
//   col[u] : label of node u, or -1 while u is unlabeled; col[0] = 0 is the
//            sentinel used as the "parent" of the root.
//   sum1   : point-assign / range-sum segment tree over [0, n]. For a labeled u,
//            position st[u] holds sz[u] minus sum1 over u's proper subtree, i.e.
//            the number of nodes in u's subtree not claimed by a labeled
//            proper descendant ("x" below).
//   sum2   : same shape; position st[u] holds col[u] minus sum2 over u's proper
//            subtree (the remaining amount "d" below).
//   Q      : per segment-tree node, the depths h[v] of labeled nodes v whose
//            range (st[v], en[v]] covers that node. The largest depth found on
//            the root-to-leaf path of st[u] is the depth of u's nearest labeled
//            proper ancestor.
//   ans    : product, over labeled nodes whose remainder d is non-negative, of
//            C(d + x - 1, x - 1) (stars and bars: ordered ways to write d as a
//            sum of x non-negative integers).
//   tz     : number of labeled nodes whose remainder is negative; while tz > 0
//            the printed answer is 0.

const int maxn = 5e5 + 12, maxk = 2e5 + 12, lg = 18, mod = 1e9 + 7;

int n, bs, tz, tms, h[maxk], col[maxk], par[maxk][lg], sz[maxk], st[maxk], en[maxk], sum1[maxk * 3], sum2[maxk * 3];
long long ans, fact[maxn], rfact[maxn];
set <int> Q[maxk * 3];
vector <int> conn[maxk];
bitset <maxn> marked, zero;

// a^b mod `mod` by recursive squaring. With b = mod - 2 this is the modular
// inverse (Fermat), valid because mod is prime and a is not a multiple of it.
long long tav(long long a, long long b){
    if(b == 0)
        return 1;
    a %= mod;
    long long x = tav(a * a % mod, b / 2);
    if(b % 2)
        return x * a % mod;
    return x % mod;
}

// Assign k to every fully covered node of [l, r] in sum1 and refresh sums.
// Only ever called with l == r (a point assignment).
void update_tree1(int l, int r, int u, int k, int L = 0, int R = n){
    if(r < L || R < l) return;
    if(l <= L && R <= r){
        sum1[u] = k;
        return;
    }
    int mid = (L + R) / 2;
    update_tree1(l, r, u * 2, k, L, mid);
    update_tree1(l, r, u * 2 + 1, k, mid + 1, R);
    sum1[u] = sum1[u * 2] + sum1[u * 2 + 1];
}

// Same as update_tree1 but for sum2.
void update_tree2(int l, int r, int u, int k, int L = 0, int R = n){
    if(r < L || R < l) return;
    if(l <= L && R <= r){
        sum2[u] = k;
        return;
    }
    int mid = (L + R) / 2;
    update_tree2(l, r, u * 2, k, L, mid);
    update_tree2(l, r, u * 2 + 1, k, mid + 1, R);
    sum2[u] = sum2[u * 2] + sum2[u * 2 + 1];
}

// Sum of sum1 over positions [l, r]; returns 0 for an empty range (l > r).
int get_sum1(int l, int r, int u, int L = 0, int R = n){
    if(r < L || R < l) return 0;
    if(l <= L && R <= r) return sum1[u];
    int mid = (L + R) / 2;
    return get_sum1(l, r, u * 2, L, mid) + get_sum1(l, r, u * 2 + 1, mid + 1, R);
}

// Sum of sum2 over positions [l, r]; returns 0 for an empty range (l > r).
int get_sum2(int l, int r, int u, int L = 0, int R = n){
    if(r < L || R < l) return 0;
    if(l <= L && R <= r) return sum2[u];
    int mid = (L + R) / 2;
    return get_sum2(l, r, u * 2, L, mid) + get_sum2(l, r, u * 2 + 1, mid + 1, R);
}

// k-th ancestor of u via binary lifting. Climbing past the root lands on the
// sentinel node 0, whose ancestors are all 0.
int kpar(int u, int k){
    for(int i = 0; i < lg; i++)
        if((k >> i) & 1)
            u = par[u][i];
    return u;
}

// DFS from u: assigns Euler-tour interval [st[u], en[u]], depth h, subtree
// size sz and the binary-lifting table. Recursive, so depth equals tree height.
// NOTE(review): a path-shaped tree needs a large call stack.
void dfs_set(int u){
    marked[u] = 1;
    sz[u] = 1;
    st[u] = tms;
    tms += 1;
    for(int i = 0; i + 1 < lg; i++)
        par[u][i + 1] = par[par[u][i]][i];
    for(int v: conn[u])
        if(!marked[v]){
            par[v][0] = u;
            h[v] = h[u] + 1;
            dfs_set(v);
            sz[u] += sz[v];
        }
    en[u] = tms - 1;
}

// Record depth h[k] on every segment node in the canonical cover of [l, r].
void add_to_tree(int l, int r, int u, int k, int L = 0, int R = n){
    if(r < L || R < l) return;
    if(l <= L && R <= r){
        Q[u].insert(h[k]);
        return;
    }
    int mid = (L + R) / 2;
    add_to_tree(l, r, u * 2, k, L, mid);
    add_to_tree(l, r, u * 2 + 1, k, mid + 1, R);
}

// Undo add_to_tree for node k. erase-by-key is a no-op if the depth is absent,
// instead of erasing through a possibly-end() iterator (undefined behavior).
void erase_from_tree(int l, int r, int u, int k, int L = 0, int R = n){
    if(r < L || R < l) return;
    if(l <= L && R <= r){
        Q[u].erase(h[k]);
        return;
    }
    int mid = (L + R) / 2;
    erase_from_tree(l, r, u * 2, k, L, mid);
    erase_from_tree(l, r, u * 2 + 1, k, mid + 1, R);
}

// Largest depth stored on any segment node whose range contains a position of
// [l, r] along the search path; -1 if none. Called with l == r == st[u] to find
// the depth of u's nearest labeled proper ancestor.
int find_pd(int l, int r, int u, int L = 0, int R = n){
    if(r < L || R < l) return -1;
    if(l <= L && R <= r){
        if(Q[u].empty()) return -1;
        return *Q[u].rbegin();
    }
    int x = -1;
    if(!Q[u].empty()) x = *Q[u].rbegin();
    int mid = (L + R) / 2;
    return max({find_pd(l, r, u * 2, L, mid), find_pd(l, r, u * 2 + 1, mid + 1, R), x});
}

// Binomial coefficient C(a, b) mod `mod`; 0 when b < 0 or a < b.
// NOTE(review): assumes a < maxn (factorial table size) — confirm against the
// problem's bounds on labels and n.
long long get_c(int a, int b){
    if(b < 0 || a < b) return 0;
    return fact[a] * (rfact[b] * rfact[a - b] % mod) % mod;
}

// Label node u with value k and update ans / tz. Presumably u is currently
// unlabeled (col[u] == -1) — the caller is expected to guarantee this.
void add_tree(int u, int k){
    // Nearest labeled proper ancestor (0 = sentinel when u is the root).
    int pd = find_pd(st[u], st[u], 1);
    pd = kpar(u, h[u] - pd);
    col[u] = k;
    // Remove pd's current factor; it is recomputed once u carves out its part.
    if(u != 1){
        int x = get_sum1(st[pd] + 1, en[pd], 1);
        int d = get_sum2(st[pd] + 1, en[pd], 1);
        if(zero[pd] == 0){
            x = sz[pd] - x;
            d = col[pd] - d;
            long long f = get_c(d + x - 1, x - 1);
            ans = ans * tav(f, mod - 2) % mod;
        }
    }
    // Register u: free-node count and remaining amount of its own region.
    int x = get_sum1(st[u], en[u], 1);
    int d = get_sum2(st[u], en[u], 1);
    update_tree1(st[u], st[u], 1, sz[u] - x);
    update_tree2(st[u], st[u], 1, k - d);
    add_to_tree(st[u] + 1, en[u], 1, u);
    if(k < d){
        // Labels below u already exceed k: no valid assignment.
        zero[u] = 1;
        tz += 1;
    }
    else{
        x = sz[u] - x;
        d = k - d;
        long long f = get_c(d + x - 1, x - 1);
        ans = ans * f % mod;
    }
    // Re-insert pd's factor with u's region now excluded.
    if(u != 1){
        int x = get_sum1(st[pd] + 1, en[pd], 1);
        int d = get_sum2(st[pd] + 1, en[pd], 1);
        update_tree1(st[pd], st[pd], 1, sz[pd] - x);
        update_tree2(st[pd], st[pd], 1, col[pd] - d);
        if(col[pd] < d){
            tz += (1 - zero[pd]);
            zero[pd] = 1;
        }
        else{
            x = sz[pd] - x;
            d = col[pd] - d;
            tz -= zero[pd];
            zero[pd] = 0;
            long long f = get_c(d + x - 1, x - 1);
            ans = ans * f % mod;
        }
    }
}

// Remove the label of node u and update ans / tz. Presumably u is currently
// labeled and is not the root — the caller is expected to guarantee this.
void remove_tree(int u){
    // Nearest labeled proper ancestor.
    int pd = find_pd(st[u], st[u], 1);
    pd = kpar(u, h[u] - pd);
    // Remove pd's current factor.
    int x = get_sum1(st[pd] + 1, en[pd], 1);
    int d = get_sum2(st[pd] + 1, en[pd], 1);
    if(zero[pd] == 0){
        x = sz[pd] - x;
        d = col[pd] - d;
        long long f = get_c(d + x - 1, x - 1);
        ans = ans * tav(f, mod - 2) % mod;
    }
    // Unregister u and remove its own factor (if it had one).
    x = get_sum1(st[u] + 1, en[u], 1);
    d = get_sum2(st[u] + 1, en[u], 1);
    update_tree1(st[u], st[u], 1, 0);
    update_tree2(st[u], st[u], 1, 0);
    erase_from_tree(st[u] + 1, en[u], 1, u);
    tz -= zero[u];
    zero[u] = 0;
    if(col[u] >= d){
        x = sz[u] - x;
        d = col[u] - d;
        long long f = get_c(d + x - 1, x - 1);
        ans = ans * tav(f, mod - 2) % mod;
    }
    col[u] = -1;
    // Re-insert pd's factor with u's former region merged back into it.
    x = get_sum1(st[pd] + 1, en[pd], 1);
    d = get_sum2(st[pd] + 1, en[pd], 1);
    update_tree1(st[pd], st[pd], 1, sz[pd] - x);
    update_tree2(st[pd], st[pd], 1, col[pd] - d);
    if(col[pd] < d){
        tz += (1 - zero[pd]);
        zero[pd] = 1;
    }
    else{
        tz -= zero[pd];
        zero[pd] = 0;
        x = sz[pd] - x;
        d = col[pd] - d;
        long long f = get_c(d + x - 1, x - 1);
        ans = ans * f % mod;
    }
}

int main(){
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    memset(col, -1, sizeof col);
    col[0] = 0;
    fact[0] = 1;
    for(int i = 1; i < maxn; i++)
        fact[i] = fact[i - 1] * i % mod;
    // Inverse factorials in O(maxn): one Fermat inverse, then 1/(i-1)! = i / i!.
    rfact[maxn - 1] = tav(fact[maxn - 1], mod - 2);
    for(int i = maxn - 1; i > 0; i--)
        rfact[i - 1] = rfact[i] * i % mod;

    cin >> n >> bs;
    for(int i = 0; i < n - 1; i++){
        int a, b;
        cin >> a >> b;
        conn[a].push_back(b);
        conn[b].push_back(a);
    }
    dfs_set(1);

    // The answer is 0 while any labeled node has a negative remainder.
    auto print_answer = [](){
        if(tz) cout << 0 << "\n";
        else cout << ans << "\n";
    };

    ans = 1;
    add_tree(1, bs);
    print_answer();

    int q;
    cin >> q;
    for(int i = 0; i < q; i++){
        int a;
        int b, c;
        cin >> a;
        if(a == 1){
            cin >> b >> c;
            add_tree(b, c);
        }
        else{
            cin >> b;
            remove_tree(b);
        }
        print_answer();
    }
}
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...