제출 #999374

#제출 시각아이디문제언어결과실행 시간메모리
999374fuad27Sumtree (INOI20_sumtree)C++17
100 / 100
680 ms88144 KiB
#include <bits/stdc++.h> using namespace std; const long long mod = 1e9+7; const int N = 4e5+10; const int MAXNC=5e5+10; const int LG = 20; vector<int> g[N]; int tin[N], tout[N], tim=0; int up[N][LG], tags[N]; long long fact[MAXNC]; int depth[N], sz[N], used[N], tagsms[N]; namespace BIT { int fen[N]; void add(int at, int vl) { at++; while(at < N) { fen[at] += vl; at+=at&(-at); } } int query(int at) { at++; int res=0; while(at) { res+=fen[at]; at-=at&(-at); } return res; } }; void change(int x, int vl) { BIT::add(tin[x], vl); BIT::add(tout[x], -vl); } int query(int anc, int ch) { return BIT::query(tin[ch])-BIT::query(tin[anc]); } long long binpow(long long a, long long b) { long long res = 1; while(b) { if(b&1) { res=(res*a)%mod; } a=(a*a)%mod; b>>=1; } return res; } long long ncr(long long nn, long long rr) { if(rr > nn)return 0; if(rr < 0)return 0; long long res = fact[nn]; res = (res*binpow(fact[rr], mod-2))%mod; res = (res*binpow(fact[nn-rr], mod-2))%mod; return res; } void dfs(int at, int p) { sz[at]=1; up[at][0] = p; for(int i = 1;i<LG;i++)up[at][i] = up[up[at][i-1]][i-1]; tin[at] = tim++; for(int to:g[at]) { if(to == p)continue; depth[to]=depth[at]+1; dfs(to, at); sz[at]+=sz[to]; } tout[at] = tim++; } struct DS1 { int fen[3*N]; DS1() { for(int i = 0;i<3*N;i++)fen[i]=0; } void upd(int at, int val) { at++; while(at<3*N) { fen[at]+=val; at+=at&(-at); } } void update(int l, int r, int val) { upd(l, val); upd(r+1, -val); } long long get(int at) { at++; long long res=0; while(at) { res+=fen[at]; at-=at&(-at); } return res; } }; namespace HLD1 { DS1 ds; int tin[N], top[N], tim=0, par[N], V[N]; void dfs(int at, int p, int tp) { tin[at]=++tim; top[at]=tp; par[at]=p; int big=-1; for(int to:g[at]) { if(to==p)continue; if(big==-1 or sz[to]>sz[big]) { big=to; } } if(big==-1)return; dfs(big, at, tp); for(int to:g[at]) { if(to==big or to==p)continue; dfs(to, at, to); } } int upd(int u,int v, long long val) { int ans=0; while(top[u]!=top[v]) { if(depth[top[u]] < depth[top[v]])swap(u,v); ds.update(tin[top[u]], tin[u], val); u=par[top[u]]; } if(depth[u]>depth[v])swap(u,v); ds.update(tin[u], tin[v], val); return ans; } long long get(int u) { return ds.get(tin[u]); } }; namespace HLD2 { DS1 ds; int tin[N], top[N], tim=0, par[N], V[N]; void dfs(int at, int p, int tp) { tin[at]=++tim; top[at]=tp; par[at]=p; int big=-1; for(int to:g[at]) { if(to==p)continue; if(big==-1 or sz[to]>sz[big]) { big=to; } } if(big==-1)return; dfs(big, at, tp); for(int to:g[at]) { if(to==big or to==p)continue; dfs(to, at, to); } } int upd(int u,int v, int val) { int ans=0; while(top[u]!=top[v]) { if(depth[top[u]] < depth[top[v]])swap(u,v); ds.update(tin[top[u]], tin[u], val); u=par[top[u]]; } if(depth[u]>depth[v])swap(u,v); ds.update(tin[u], tin[v], val); return ans; } int get(int u) { return ds.get(tin[u]); } }; int findroot(int at) { for(int i = LG-1;i>=0;i--) { if(up[at][i] != 0 and query(up[at][i], at) == depth[at]-depth[up[at][i]]) { at = up[at][i]; } } return at; } struct ANS { long long mult = 1; long long zero = 0; void add(long long vl) { if(vl == 0) zero++; else mult = (mult*vl)%mod; } void del(long long vl) { if(vl == 0) { zero--; } else { mult = (mult*binpow(vl, mod-2))%mod; } } long long get() { if(zero>0)return 0; return mult; } }; int main () { cin.tie(0)->sync_with_stdio(0); fact[0] = 1; for(int i = 1;i<MAXNC;i++)fact[i] = (fact[i-1]*i)%mod; int n, r; cin >> n >> r; for(int i = 1;i<n;i++) { int u, v; cin >> u >> v; g[u].push_back(v); g[v].push_back(u); } dfs(1, 0); HLD1::dfs(1, 0, 1); HLD2::dfs(1, 0, 1); tags[1] = r; used[1] = 1; for(int i = 2;i<=n;i++){ change(i, 1); } for(int i= 1;i<=n;i++) { HLD1::upd(i, i, sz[i]); HLD2::upd(i, i, tagsms[i]); } ANS ans; ans.add(ncr(n+r-1, r)); cout << ans.get() << "\n"; int q; cin >> q; while(q--) { int t; cin >> t; if(t == 1) { int u, v; cin >> u >> v; tags[u]=v; int rr = findroot(u); { long long tmp = tags[rr] - HLD2::get(rr); long long szz = HLD1::get(rr); ans.del(ncr(szz+tmp-1, tmp)); } int par = up[u][0]; HLD1::upd(par, rr, -HLD1::get(u)); HLD2::upd(par, rr, v-HLD2::get(u)); // while(1) { // sz[par]-=sz[u]; // tagsms[par]-=tagsms[u]; // tagsms[par]+=v; // if(par==rr)break; // par = up[par][0]; // } { long long tmp = tags[u] - HLD2::get(u); long long szz = HLD1::get(u); ans.add(ncr(szz+tmp-1, tmp)); } { long long tmp = tags[rr] - HLD2::get(rr); long long szz = HLD1::get(rr); ans.add(ncr(szz+tmp-1, tmp)); } change(u, -1); } else { int u; cin >> u; int v = tags[u]; change(u, 1); int rr = findroot(u); { long long tmp = tags[rr] - HLD2::get(rr); long long szz = HLD1::get(rr); ans.del(ncr(szz+tmp-1, tmp)); } int par = up[u][0]; HLD1::upd(par, rr, HLD1::get(u)); HLD2::upd(par, rr, HLD2::get(u)-v); // while(1) { // sz[par]+=sz[u]; // tagsms[par]+=tagsms[u]; // tagsms[par]-=v; // if(par==rr)break; // par = up[par][0]; // } { long long tmp = tags[u] - HLD2::get(u); long long szz = HLD1::get(u); ans.del(ncr(szz+tmp-1, tmp)); } { long long tmp = tags[rr] - HLD2::get(rr); long long szz = HLD1::get(rr); ans.add(ncr(szz+tmp-1, tmp)); } } cout << ans.get() << "\n"; } }
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...