Submission #947680

#TimeUsernameProblemLanguageResultExecution timeMemory
947680PM1Sumtree (INOI20_sumtree)C++17
10 / 100
492 ms275856 KiB
#include <bits/stdc++.h> using namespace std; #define ll long long #define fr first #define sc second const int mxn=5e5+5,M=1e9+7,kkk=(1<<21); int n,q,r,st[mxn],fn[mxn],cnt=0,sz[mxn],g[mxn],bbb; ll ans,fuck[mxn],rfuck[mxn]; vector<int>v[mxn]; ll ferma(ll x){ ll num=M-2,res=1; while(num){ if(num&1)res=(res*x)%M; x=(x*x)%M; num/=2; } return res; } void dfs(int z){ sz[z]=1; st[z]=++cnt; for(auto i:v[z]){ if(!st[i]){ dfs(i); sz[z]+=sz[i]; } } fn[z]=cnt; } void make(){ fuck[0]=rfuck[0]=1; for(int i=1;i<=n+r;i++){ fuck[i]=(fuck[i-1]*i)%M; rfuck[i]=ferma(fuck[i]); } } ll comb(int x,int y,bool w){ if(x>y){ bbb+=(w)?-1:1; return 1; } ll res=fuck[y]*rfuck[x]; res%=M; res*=rfuck[y-x]; return res%M; } struct segment{ ll val[kkk]; void up(int id,int L,int R,int l,int x){ if(L+1==R){ val[id]=x; return; } int mid=(L+R)/2; if(l<mid) up(id*2,L,mid,l,x); else up(id*2+1,mid,R,l,x); val[id]=val[id*2]+val[id*2+1]; } ll get(int id ,int L,int R,int l,int r){ if(l>=R)return 0; if(L==l && R==r) return val[id]; int mid=(L+R)/2; ll res=0; if(l<mid) res+=get(id*2,L,mid,l,min(r,mid)); if(r>mid) res+=get(id*2+1,mid,R,max(l,mid),r); return res; } }seg[2]; struct fnd{ set<int>s[kkk]; int get(int id,int L,int R,int l,int r){ if(L==l && R==r) return ((s[id].size())?*s[id].rbegin():0); int mid=(L+R)/2,res=((s[id].size())?*s[id].rbegin():0),x=0; if(l<mid){ x=get(id*2,L,mid,l,min(r,mid)); } return (x)?x:res; } void add(int id ,int L,int R,int l,int r,int x,bool y){ if(L==l && R==r){ if(y) s[id].insert(x); else s[id].erase(x); return; } int mid=(L+R)/2; if(l<mid) add(id*2,L,mid,l,min(r,mid),x,y); if(r>mid) add(id*2+1,mid,R,max(l,mid),r,x,y); } }fnd; int main(){ ios::sync_with_stdio(false); cin.tie(0); cout.tie(0); cin>>n>>r; for(int i=1;i<n;i++){ int x,y; cin>>x>>y; v[x].push_back(y); v[y].push_back(x); } dfs(1); make(); ans=comb(n-1,n+r-1,0); seg[0].up(1,1,n+1,1,n); seg[1].up(1,1,n+1,1,r); fnd.add(1,1,n+1,1,n+1,1,1); g[1]=r; cout<<ans<<'\n'; cin>>q; while(q--){ int ty,x,y,z; cin>>ty; if(ty==1){ cin>>x>>g[x]; z=fnd.get(1,1,n+1,st[x],fn[x]+1); fnd.add(1,1,n+1,st[x],fn[x]+1,st[x],1); } else{ cin>>x; fnd.add(1,1,n+1,st[x],st[y]+1,st[x],0); z=fnd.get(1,1,n+1,st[x],st[y]+1); } assert(z!=0); int t1=seg[0].get(1,1,n+1,st[x],fn[x]+1)-seg[0].get(1,1,n+1,st[x],st[x]+1); int t2=seg[0].get(1,1,n+1,st[z],fn[z]+1)-seg[0].get(1,1,n+1,st[z],st[z]+1); ll v1=seg[1].get(1,1,n+1,st[x],fn[x]+1); ll v2=seg[1].get(1,1,n+1,st[z],fn[z]+1); if(ty==1){ int sz1=sz[x]-t1; int sz2=sz[z]-t2-sz1; seg[0].up(1,1,n+1,st[x],sz1); seg[0].up(1,1,n+1,st[z],sz2); ans*=ferma(comb(sz[z]-t2-1,sz[z]-t2+v2-1,1)); ans%=M; //cout<<comb(sz[z]-t2-1,sz[z]-t2+v2-1,1)<<" "; int val1=comb(sz1-1,sz1-1+g[x]-v1,0); int val2=comb(sz2-1,sz2-1+v2-g[x]+v1,0); ans*=val1; ans%=M; ans*=val2; ans%=M; seg[1].up(1,1,n+1,st[x],g[x]-v1-v1); seg[1].up(1,1,n+1,st[z],0); ll w=seg[1].get(1,1,n+1,st[z],fn[z]+1); seg[1].up(1,1,n+1,st[z],g[z]-w-w); } else{ int sz1=sz[x]-t1; int sz2=sz[z]-t2+sz1; seg[0].up(1,1,n+1,st[x],0); seg[0].up(1,1,n+1,st[z],sz2); //cout<<v1<<" "<<v2<<'\n'; ans*=ferma(comb(sz1-1,sz1+v1-1,1)); ans%=M; ans*=ferma(comb(sz[z]-t2-1,sz[z]-t2+v2-1,1)); ans%=M; int val1=comb(sz2-1,sz2-1+v2+v1,0); ans*=val1; ans%=M; seg[1].up(1,1,n+1,st[x],0); seg[1].up(1,1,n+1,st[z],0); ll w=seg[1].get(1,1,n+1,st[z],fn[z]+1); seg[1].up(1,1,n+1,st[z],g[z]-w-w); } if(bbb){ cout<<0<<'\n'; } else cout<<ans<<'\n'; } return 0; }

Compilation message (stderr)

Main.cpp: In function 'int main()':
Main.cpp:130:30: warning: 'y' may be used uninitialized in this function [-Wmaybe-uninitialized]
  130 |    fnd.add(1,1,n+1,st[x],st[y]+1,st[x],0);
      |                          ~~~~^
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...