답안 #1036901

# 제출 시각 아이디 문제 언어 결과 실행 시간 메모리
1036901 2024-07-27T19:25:47 Z PM1 Sumtree (INOI20_sumtree) C++17
컴파일 오류
0 ms 0 KB
#include <bits/stdc++.h>
using namespace std;
#define fr first
#define sc second
#define ll long long
const int mxn=5e5+5,M=1e9+7,sz=(1<<21);
// Bound for the factorial tables. main() fills f/rf for every index up to
// n+r and comb() reads up to index n-1+r, so sizing them by mxn alone
// overflows as soon as n+r >= mxn.
// NOTE(review): assumes n, r <= 5e5 (so n+r < 2*mxn) — confirm against the
// problem constraints, and that query values never exceed r.
const int mxf=2*mxn;
int n,r,q;
// st/fn: DFS entry index and last index in the subtree; pos: inverse of st;
// par: DFS parent; zr: subtree size; bad: count of currently infeasible factors.
int st[mxn],fn[mxn],cnt,par[mxn],zr[mxn],bad=0,pos[mxn];
// f: factorials mod M, rf: their modular inverses; ans: running product mod M.
ll f[mxf],rf[mxf],ans=1;
// Adjacency lists of the tree (1-indexed vertices).
vector<int>v[mxn];
// Computes x^num mod M by binary exponentiation. With the default exponent
// M-2 this is the modular inverse of x (Fermat's little theorem, M is prime).
// x is expected to already be reduced into [0, M).
ll ferma(ll x,ll num=M-2){
	ll acc=1;
	for(ll base=x; num!=0; num>>=1){
		if(num&1)
			acc=acc*base%M;
		base=base*base%M;
	}
	return acc;
}
void dfs(int z){
	st[z]=++cnt;
	pos[cnt]=z;
	zr[z]=1;
	for(auto i:v[z]){
		if(par[z]!=i){
			par[i]=z;
			dfs(i);
			zr[z]+=zr[i];
		}
	}
	fn[z]=cnt;
}
// Binomial coefficient C(y, x) mod M using the factorial tables f / rf.
// When x > y the factor is infeasible: rather than returning 0 (which could
// not later be divided back out of `ans`), it returns 1 and adds z to the
// global `bad` counter. Callers pass z = +1 when multiplying the factor into
// `ans` and z = -1 when dividing it back out, which keeps `bad` balanced.
// z defaults to +1, the value needed for the initial whole-tree factor that
// main() multiplies in with only two arguments.
ll comb(ll x,ll y,int z=1){
	if(x>y){
		bad+=z;
		return 1;
	}
	ll res=f[y]*rf[x]%M;
	return res*rf[y-x]%M;
}
// Point-update / range-sum segment tree. Node id covers the half-open
// interval [L, R); its children are 2*id and 2*id+1. Callers in this file use
// the root call (1, 1, n+1, ...).
struct segment{
	int val[sz];
	// Sum of stored values at positions in [l, r) within node interval [L, R).
	int get(int id,int L,int R,int l,int r){
		const bool disjoint=(R<=L)||(r<=L)||(r<=l)||(R<=l);
		if(disjoint)
			return 0;
		const bool covered=(l<=L)&&(R<=r);
		if(covered)
			return val[id];
		const int half=(L+R)>>1;
		const int left=get(2*id,L,half,l,r);
		const int right=get(2*id+1,half,R,l,r);
		return left+right;
	}
	// Add x at position l, then recompute the sums on the way back up.
	void add(int id,int L,int R,int l,int x){
		if(R-L!=1){
			const int half=(L+R)>>1;
			if(l>=half)
				add(2*id+1,half,R,l,x);
			else
				add(2*id,L,half,l,x);
			val[id]=val[2*id]+val[2*id+1];
			return;
		}
		val[id]+=x;
	}
}seg[2];
// Segment tree whose nodes each hold a set of integer labels (callers in this
// file use st[] entry indices, which are >= 1, so 0 means "none").
// add() inserts (kos = true) or erases (kos = false) label x on the canonical
// cover of [l, r). get() walks the root-to-leaf path of position l and returns
// the largest label held by the DEEPEST node on that path whose set is
// non-empty, or 0 if every set on the path is empty.
// For nested (laminar) intervals such as DFS subtree ranges, the deepest
// non-empty node holds the innermost interval containing l, so this yields
// the marked interval with the largest label covering l.
struct kirkhar{
	set<int>s[sz];
	int get(int id,int L,int R,int l){
		if(L+1==R)
			return s[id].empty()?0:*s[id].rbegin();
		const int mid=(L+R)>>1;
		const int res=(l<mid)?get(id<<1,L,mid,l):get((id<<1)+1,mid,R,l);
		// Nothing deeper on the path: fall back to this node's own labels.
		if(res==0)
			return s[id].empty()?0:*s[id].rbegin();
		return res;
	}
	void add(int id,int L,int R,int l,int r,int x,bool kos){
		if(L>=R || L>=r || l>=r || l>=R)
			return ;
		if(L>=l && R<=r){
			if(kos)
				s[id].insert(x);
			else
				s[id].erase(x);
			return ;
		}
		const int mid=(L+R)>>1;
		add(id<<1,L,mid,l,r,x,kos);
		add((id<<1)+1,mid,R,l,r,x,kos);
	}
}gg;
int main(){
	//ios::sync_with_stdio(false);
	//cin.tie(0);
	//cout.tie(0);
	cin>>n>>r;
	for(int i=1;i<n;i++){
		int x,y;
		cin>>x>>y;
		v[x].push_back(y);
		v[y].push_back(x);
	}
	dfs(1);
	f[0]=1;
	rf[0]=1;
	for(int i=1;i<=n+r;i++){
		f[i]=(f[i-1]*i)%M;
		rf[i]=ferma(f[i]);
		//cout<<f[i]<<" "<<rf[i]<<'\n';
	}
	seg[0].add(1,1,n+1,1,r);
	seg[1].add(1,1,n+1,1,n);
	gg.add(1,1,n+1,1,n+1,1,1);
	ans*=comb(n-1,n-1+r);
	ans%=M;
	cout<<ans<<'\n';
	cin>>q;
	while(q--){
		ll ty,x,y;
		cin>>ty>>x;
		if(ty==1){
			int ycnt=zr[x];
			cin>>y;
			ll zir=seg[0].get(1,1,n+1,st[x]+1,fn[x]+1);
			ll zircnt=seg[1].get(1,1,n+1,st[x]+1,fn[x]+1);
			y-=zir;
			ycnt-=zircnt;
			int p=gg.get(1,1,n+1,st[x]);
			p=pos[p];
			ll pre=seg[0].get(1,1,n+1,st[p],fn[p]+1)-seg[0].get(1,1,n+1,st[p]+1,fn[p]+1);
			ll precnt=seg[1].get(1,1,n+1,st[p],fn[p]+1)-seg[1].get(1,1,n+1,st[p]+1,fn[p]+1);
			ans*=ferma(comb(precnt-1,precnt-1+pre,-1));
			ans%=M;
			
			pre-=y;
			precnt-=ycnt;
			//cout<<y-zir<<" "<<ycnt-zircnt<<'\n';
			seg[0].add(1,1,n+1,st[x],y);
			seg[1].add(1,1,n+1,st[x],ycnt);
			seg[0].add(1,1,n+1,st[p],-y);
			seg[1].add(1,1,n+1,st[p],-ycnt);
			ans*=comb(precnt-1,precnt+pre-1,1);
			ans%=M;
			ans*=comb(ycnt-1,ycnt+y-1,1);
			ans%=M;
			gg.add(1,1,n+1,st[x],fn[x]+1,st[x],1);
		}
		else{
			gg.add(1,1,n+1,st[x],fn[x]+1,st[x],0);
			ll zir=seg[0].get(1,1,n+1,st[x],fn[x]+1)-seg[0].get(1,1,n+1,st[x]+1,fn[x]+1);
			ll zircnt=seg[1].get(1,1,n+1,st[x],fn[x]+1)-seg[1].get(1,1,n+1,st[x]+1,fn[x]+1);
			int p=gg.get(1,1,n+1,st[x]);
			p=pos[p];
			ll pre=seg[0].get(1,1,n+1,st[p],fn[p]+1)-seg[0].get(1,1,n+1,st[p]+1,fn[p]+1);
			ll precnt=seg[1].get(1,1,n+1,st[p],fn[p]+1)-seg[1].get(1,1,n+1,st[p]+1,fn[p]+1);
			ans*=ferma(comb(precnt-1,precnt-1+pre,-1));
			ans%=M;
			ans*=ferma(comb(zircnt-1,zircnt+zir-1,-1));
			ans%=M;
			pre+=zir;
			precnt+=zircnt;
			seg[0].add(1,1,n+1,st[x],-zir);
			seg[1].add(1,1,n+1,st[x],-zircnt);
			seg[0].add(1,1,n+1,st[p],+zir);
			seg[1].add(1,1,n+1,st[p],+zircnt);			
			ans*=comb(precnt-1,precnt+pre-1,1);
			ans%=M;
			
		}
		if(bad)
			cout<<0<<'\n';
		else
			cout<<ans<<'\n';
	}
	return 0;
 
}

Compilation message

Main.cpp: In member function 'void kirkhar::add(int, int, int, int, int, int, bool)':
Main.cpp:94:20: warning: unused variable 'res' [-Wunused-variable]
   94 |   int mid=(L+R)>>1,res=0;
      |                    ^~~
Main.cpp: In function 'int main()':
Main.cpp:122:21: error: too few arguments to function 'long long int comb(long long int, long long int, int)'
  122 |  ans*=comb(n-1,n-1+r);
      |                     ^
Main.cpp:33:4: note: declared here
   33 | ll comb(ll x,ll y,int z){
      |    ^~~~