Submission #1293255

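| # | Time | Username | Problem | Language | Result | Execution time | Memory |
|---|------|----------|---------|----------|--------|----------------|--------|
| 1293255 | | thdh__ | Sumtree (INOI20_sumtree) | C++20 | 100 / 100 | 939 ms | 96912 KiB |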
#include <bits/stdc++.h>
#define ll long long
#define pb push_back
#define bruh ios_base::sync_with_stdio(false); cin.tie(0); cout.tie(0);
#define int ll
using namespace std;

/*
Competitive programming notes (personal checklist):
1. Coding:
   - Always check array sizes (consider vectors) and loop bounds.
   - Size arrays for the full constraints, even when only going for subtasks.
   - Do not abuse #define int long long.
2. Stress testing:
   - Generate big test cases and check that the program runs on them.
3. Time management:
   - Don't over- or under-commit: spend a fixed amount of time thinking about
     a problem before giving up on it.
   - Brute-force solutions should be quick to write.
   Schedule for offline / LAH days (4 problems, 3h):
     15' thinking of a solution idea
       1. no idea: skip
       2. idea: think <= 15' more, implement <= 20',
          brute force <= 5', test generator <= 5'
   Before submitting: go through the usual mistakes and re-read the code.
*/

constexpr int N = 1000005;
constexpr int mod = 1000000007;

// a^b mod `mod` by binary exponentiation.
int power(int a, int b) {
    int x = 1;
    a %= mod;
    while (b) {
        if (b & 1) x = x * a % mod;
        a = a * a % mod;
        b >>= 1;
    }
    return x;
}

// gt[i] = i! mod p, inv[i] = (i!)^{-1} mod p for 0 <= i < N.
int gt[N], inv[N];
void precalc() {
    gt[0] = 1;
    for (int i = 1; i < N; i++) gt[i] = gt[i - 1] * i % mod;
    inv[N - 1] = power(gt[N - 1], mod - 2);
    for (int i = N - 2; i >= 0; i--) inv[i] = inv[i + 1] * (i + 1) % mod;
}

// Binomial coefficient C(n, k) mod p. Returns 0 for any impossible choice
// (negative n or k, or k > n). Requires n < N whenever 0 < k < n.
int C(int n, int k) {
    if (n < 0 || k < 0 || k > n) return 0;  // previously C(negative, 0) was 1
    if (k == 0 || k == n) return 1;
    return gt[n] * inv[k] % mod * inv[n - k] % mod;
}

// Stars and bars: number of ways to write r as an ordered sum of n >= 1
// non-negative integers; 0 when r < 0.
int calc(int n, int r) { return C(n + r - 1, n - 1); }

// Position segment tree over HLD order. A leaf holds the node id if that node
// is constrained ("on"), otherwise -1. Internal nodes keep the right-most "on"
// value, so a query over a prefix of a heavy path returns its deepest "on" node.
int stp[4 * N];
void updatep(int id, int l, int r, int i, int val) {
    if (l == r) {
        stp[id] = val;
        return;
    }
    int mid = (l + r) >> 1;
    if (i <= mid) updatep(id * 2, l, mid, i, val);
    else updatep(id * 2 + 1, mid + 1, r, i, val);
    stp[id] = (stp[id * 2 + 1] != -1) ? stp[id * 2 + 1] : stp[id * 2];
}
int getp(int id, int l, int r, int u, int v) {
    if (l > v || r < u) return -1;
    if (u <= l && r <= v) return stp[id];
    int mid = (l + r) >> 1;
    int le = getp(id * 2, l, mid, u, v);
    int ri = getp(id * 2 + 1, mid + 1, r, u, v);
    return ri != -1 ? ri : le;
}

// Point-assign / range-sum segment tree of region sizes (cnt) in HLD order.
int stsz[4 * N];
void updatesz(int id, int l, int r, int i, int val) {
    if (l == r) {
        stsz[id] = val;
        return;
    }
    int mid = (l + r) >> 1;
    if (i <= mid) updatesz(id * 2, l, mid, i, val);
    else updatesz(id * 2 + 1, mid + 1, r, i, val);
    stsz[id] = stsz[id * 2] + stsz[id * 2 + 1];
}
int getsz(int id, int l, int r, int u, int v) {
    if (l > v || r < u) return 0;
    if (u <= l && r <= v) return stsz[id];
    int mid = (l + r) >> 1;
    return getsz(id * 2, l, mid, u, v) + getsz(id * 2 + 1, mid + 1, r, u, v);
}

// Point-assign / range-sum segment tree of remaining region sums (vl) in HLD order.
int stv[4 * N];
void updatev(int id, int l, int r, int i, int val) {
    if (l == r) {
        stv[id] = val;
        return;
    }
    int mid = (l + r) >> 1;
    if (i <= mid) updatev(id * 2, l, mid, i, val);
    else updatev(id * 2 + 1, mid + 1, r, i, val);
    stv[id] = stv[id * 2] + stv[id * 2 + 1];
}
int getv(int id, int l, int r, int u, int v) {
    if (l > v || r < u) return 0;
    if (u <= l && r <= v) return stv[id];
    int mid = (l + r) >> 1;
    return getv(id * 2, l, mid, u, v) + getv(id * 2 + 1, mid + 1, r, u, v);
}

int n, r, q;
// a[u]: required subtree sum of u, or -1 if u is unconstrained.
// For a constrained u, its region is subtree(u) minus the subtrees of
// constrained proper descendants; cnt[u] = #nodes in the region,
// vl[u] = sum still to be distributed over it, c[u] = calc(cnt[u], vl[u]).
int a[N], c[N], cnt[N], vl[N];
vector<int> adj[N];
// ans = product of the non-zero factors c[x]; z = number of zero factors.
int ans, z = 0;

// HLD
int sz[N], heavy[N], head[N], tin[N], tout[N], timer = 0, par[N];

// Computes par, subtree sizes and heavy children of the tree rooted at u
// (whose parent is p). Iterative, so deep trees cannot overflow the call stack.
void predfs(int u, int p) {
    vector<int> order;
    order.reserve(n);
    vector<int> st;
    st.pb(u);
    par[u] = p;
    while (!st.empty()) {
        int x = st.back();
        st.pop_back();
        order.pb(x);
        sz[x] = 1;
        for (int y : adj[x])
            if (y != par[x]) {
                par[y] = x;
                st.pb(y);
            }
    }
    // Reverse discovery order visits every child before its parent.
    for (auto it = order.rbegin(); it != order.rend(); ++it)
        if (*it != u) sz[par[*it]] += sz[*it];
    // First strictly-largest child in adjacency order, as in a recursive DFS.
    for (int x : order)
        for (int y : adj[x])
            if (y != par[x] && sz[y] > sz[heavy[x]]) heavy[x] = y;
}

// Assigns head/tin/tout: heavy child first, then light children in adjacency
// order, exactly matching the recursive pre-order. Iterative for deep trees.
void decompose(int u, int p) {
    vector<pair<int, int>> st;
    st.pb({u, p});
    while (!st.empty()) {
        auto [x, h] = st.back();
        st.pop_back();
        head[x] = h;
        tin[x] = ++timer;
        tout[x] = tin[x] + sz[x] - 1;  // subtree occupies a contiguous tin range
        // Pushed in reverse so they pop in adjacency order, after the heavy child.
        for (auto it = adj[x].rbegin(); it != adj[x].rend(); ++it)
            if (*it != par[x] && *it != heavy[x]) st.pb({*it, *it});
        if (heavy[x]) st.pb({heavy[x], h});
    }
}

// Nearest constrained node among u and its ancestors, or -1 if none.
int up(int u) {
    while (u) {
        int found = getp(1, 1, n, tin[head[u]], tin[u]);
        if (found != -1) return found;
        u = par[head[u]];
    }
    return -1;
}

// Takes node x's current factor c[x] out of the running product.
void dropFactor(int x) {
    if (c[x]) ans = ans * power(c[x], mod - 2) % mod;
    else z--;
}

// Recomputes c[x] from cnt[x], vl[x] and puts it into the running product.
void takeFactor(int x) {
    c[x] = calc(cnt[x], vl[x]);
    if (c[x]) ans = ans * c[x] % mod;
    else z++;
}

// Removes the constraint on u, merging its region back into the region of the
// nearest constrained ancestor. No-op for the root or an unconstrained node
// (previously this indexed c[-1]).
void del(int u) {
    if (u == 1 || a[u] == -1) return;
    int x = up(par[u]);
    dropFactor(x);
    dropFactor(u);
    vl[x] += vl[u];
    cnt[x] += cnt[u];
    vl[u] = cnt[u] = c[u] = 0;
    updatev(1, 1, n, tin[u], vl[u]);
    updatev(1, 1, n, tin[x], vl[x]);
    updatesz(1, 1, n, tin[u], cnt[u]);
    updatesz(1, 1, n, tin[x], cnt[x]);
    updatep(1, 1, n, tin[u], -1);
    takeFactor(x);
    a[u] = -1;
}

// Constrains subtree(u) to sum to val, splitting u's region off the region of
// its nearest constrained ancestor. An existing constraint on u is overwritten.
void add(int u, int val) {
    if (a[u] != -1) {
        if (u == 1) {  // root: only its required sum changes
            dropFactor(1);
            vl[1] += val - a[1];
            a[1] = val;
            updatev(1, 1, n, tin[1], vl[1]);
            takeFactor(1);
            return;
        }
        del(u);
    }
    a[u] = val;
    int x = up(u);
    dropFactor(x);
    // Constrained descendants already account for part of the sum and nodes.
    vl[u] = val - getv(1, 1, n, tin[u], tout[u]);
    cnt[u] = sz[u] - getsz(1, 1, n, tin[u], tout[u]);
    vl[x] -= vl[u];
    cnt[x] -= cnt[u];
    updatev(1, 1, n, tin[u], vl[u]);
    updatev(1, 1, n, tin[x], vl[x]);
    updatesz(1, 1, n, tin[u], cnt[u]);
    updatesz(1, 1, n, tin[x], cnt[x]);
    updatep(1, 1, n, tin[u], u);
    takeFactor(u);
    takeFactor(x);
}

// Reads the tree and queries. Prints the product of region factors after the
// initial state and after each query (presumably the number of valid
// non-negative value assignments mod p).
void solve() {
    memset(stp, -1, sizeof(stp));
    precalc();
    cin >> n >> r;
    for (int i = 1; i < n; i++) {
        int u, v;
        cin >> u >> v;
        adj[u].pb(v);
        adj[v].pb(u);
    }
    if (n == 1) {  // a single node always has exactly one assignment
        cout << 1 << '\n';
        cin >> q;
        while (q--) cout << 1 << '\n';
        return;
    }
    for (int i = 1; i <= n; i++) a[i] = -1;
    a[1] = r;
    cnt[1] = n;
    vl[1] = r;
    ans = 1;
    takeFactor(1);  // c[1] = calc(n, r)
    // tin[1] == 1 once decompose(1, 1) runs, so position 1 is the root.
    updatesz(1, 1, n, 1, cnt[1]);
    updatev(1, 1, n, 1, vl[1]);
    updatep(1, 1, n, 1, 1);
    cout << (z ? 0 : ans) << '\n';
    predfs(1, 0);
    decompose(1, 1);
    cin >> q;
    while (q--) {
        int tp, u;
        cin >> tp >> u;
        if (tp == 1) {
            int v;
            cin >> v;
            add(u, v);
        } else {
            del(u);
        }
        cout << (z ? 0 : ans) << '\n';  // '\n': no per-query flush
    }
}

signed main() {
    bruh
    int t = 1;
    while (t--) {
        solve();
        cout << "\n";
    }
}
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...