Submission #999103

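| # | Time | Username | Problem | Language | Result | Execution time | Memory |
|---|------|----------|---------|----------|--------|----------------|--------|
| 999103 | | vjudge1 | Sumtree (INOI20_sumtree) | C++17 | 10 / 100 | 549 ms | 302932 KiB |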
#include <bits/stdc++.h>

using namespace std;

#define int long long

// Size bound for every per-vertex table (including the factorial table used by nck).
constexpr int MXN = 1000005;
// Number of binary-lifting levels kept in p[][] (jumps of up to 2^(LOG-1)).
constexpr int LOG = 20;
// Prime modulus for every answer.
constexpr int mod = 1000000007;

// Segment tree over positions 1..n supporting "add val to every position in a range"
// and "read the accumulated value at one position".
// A range add is recorded on the canonical nodes covering the range and is never
// pushed down; the value at a position is the sum of the records on its
// root-to-leaf path. Callers start every operation at the root: (l, r, x) = (1, n, 1).
struct FindUp
{
	int sz;          // number of positions, as passed to init
	vector<int> st;  // per-node pending add, 1-indexed heap layout

	// Sizes the tree for positions 1..n with every value zero.
	void init(int n)
	{
		sz = n;
		st = vector<int>(4 * (n + 1), 0);
	}

	// Adds val to every position in [lx, rx]; node x spans [l, r].
	void add(int l, int r, int x, int lx, int rx, int val)
	{
		if (rx < l || lx > r) return;  // no overlap with this node
		if (lx <= l && r <= rx)        // node fully inside the range: record here
		{
			st[x] += val;
			return;
		}
		const int mid = (l + r) >> 1;
		const int lc = x << 1;
		add(l, mid, lc, lx, rx, val);
		add(mid + 1, r, lc | 1, lx, rx, val);
	}

	// Returns the accumulated value at position ind, walking from node x
	// (spanning [l, r]) down to the leaf and summing every record met on the way.
	int get(int l, int r, int x, int ind)
	{
		int total = 0;
		int lo = l;
		int hi = r;
		int node = x;
		for (;;)
		{
			total += st[node];
			if (lo == hi) return total;
			const int mid = (lo + hi) >> 1;
			if (ind > mid)
			{
				lo = mid + 1;
				node = (node << 1) | 1;
			}
			else
			{
				hi = mid;
				node <<= 1;
			}
		}
	}
};
// Segment tree over positions 1..n supporting "add val at one position" and
// "sum of the values in a range". Every node stores the sum of its span.
// Callers start every operation at the root: (l, r, x) = (1, n, 1).
struct FindVAL
{
	int sz;          // number of positions, as passed to init
	vector<int> st;  // per-node span sum, 1-indexed heap layout

	// Sizes the tree for positions 1..n with every value zero.
	void init(int n)
	{
		sz = n;
		st = vector<int>(4 * (n + 1), 0);
	}

	// Adds val at position ind and refreshes the span sums on the way back up;
	// node x spans [l, r].
	void add(int l, int r, int x, int ind, int val)
	{
		if (l != r)
		{
			const int mid = (l + r) >> 1;
			const int lc = x << 1;
			if (ind > mid) add(mid + 1, r, lc | 1, ind, val);
			else add(l, mid, lc, ind, val);
			st[x] = st[lc] + st[lc | 1];
			return;
		}
		st[x] += val;
	}

	// Returns the sum over [lx, rx] intersected with node x's span [l, r].
	// An empty request (lx > rx) yields 0.
	int get(int l, int r, int x, int lx, int rx)
	{
		if (rx < l || lx > r) return 0;         // no overlap
		if (lx <= l && r <= rx) return st[x];   // fully covered
		const int mid = (l + r) >> 1;
		const int lc = x << 1;
		const int fromLeft = get(l, mid, lc, lx, rx);
		const int fromRight = get(mid + 1, r, lc | 1, lx, rx);
		return fromLeft + fromRight;
	}
};

// n: number of vertices; r: value required at the root (tag[1] is set to r in main).
int n, r;
// Undirected adjacency lists of the tree.
vector<int> adj[MXN];
// f[i]: i! modulo mod. tag[v]: required subtree sum of v, or -1 when v is untagged.
int f[MXN], tag[MXN];
// p[i][v]: the 2^i-th ancestor of v (the root is its own parent); sz[v]: subtree size.
// NOTE(review): with `int` defined as long long, p alone is LOG * MXN * 8 bytes (about 160 MB).
int p[LOG][MXN], sz[MXN];
// Euler-tour entry/exit indices: the subtree of v occupies positions in[v]..out[v].
int in[MXN], out[MXN], tim;
// stup: at position in[v], the number of tagged vertices among v and its ancestors.
FindUp stup;
// At in[v] for tagged v: stsz holds sz[v] minus the sizes of v's nearest tagged
// descendants, and sts holds tag[v] minus their tags. Summing over in[v]+1..out[v]
// therefore gives the totals of v's nearest tagged descendants.
FindVAL stsz, sts;
// res: product of getco over every tagged vertex that is not over-constrained;
// z: number of tagged vertices that are over-constrained (answer is 0 while z > 0).
int res = 1, z = 0;

// Returns a^b modulo c using binary exponentiation.
// The modulus parameter c is now honoured. The previous version reduced every
// product by the global `mod` regardless of c.
// Assumes b >= 0 (a negative b returns 1 % c instead of looping) and c >= 1.
// Callers pass non-negative a; intermediate products must fit in `int`
// (long long in this file), so c should stay below about 3e9.
int pw(int a, int b, int c)
{
	a %= c;
	int acc = 1 % c;  // 1 % c so that c == 1 correctly yields 0
	while (b > 0)
	{
		if (b & 1) acc = (acc * a) % c;
		a = (a * a) % c;
		b >>= 1;
	}
	return acc;
}

// Binomial coefficient C(n, k) modulo mod, using the factorial table f and a
// Fermat inverse of the denominator (mod is prime).
// Returns 0 whenever the coefficient is undefined: n < 0, k < 0, or k > n.
// NOTE(review): assumes n < MXN so that f[n] is in range.
int nck(int n, int k)
{
	const bool undefined = n < 0 || k < 0 || k > n;
	if (undefined) return 0;
	const int denom = (f[k] * f[n - k]) % mod;
	const int denomInv = pw(denom, mod - 2, mod);
	return (f[n] * denomInv) % mod;
}

// Traverses the subtree rooted at a and fills the per-vertex tables:
//   p[0][v]  parent of every visited vertex v other than a,
//   sz[v]    subtree size,
//   in[v]    pre-order index (1-based, via the global counter tim),
//   out[v]   largest pre-order index inside v's subtree.
// p[0][a] must already be set by the caller; neighbours equal to it are skipped.
// An explicit stack replaces recursion: depth equals tree depth, and a path-shaped
// tree with hundreds of thousands of vertices can overflow the call stack.
// Children are visited in adjacency order, so every value matches the recursive
// traversal.
void dfs(int a)
{
	// Each frame is (vertex, index of the next adjacency entry to examine).
	vector<pair<int, size_t>> stk;
	sz[a] = 1;
	in[a] = ++tim;
	stk.emplace_back(a, size_t{0});
	while (!stk.empty())
	{
		const int u = stk.back().first;
		size_t &idx = stk.back().second;
		if (idx < adj[u].size())
		{
			const int v = adj[u][idx++];
			if (v == p[0][u]) continue;
			p[0][v] = u;
			sz[v] = 1;
			in[v] = ++tim;
			stk.emplace_back(v, size_t{0});  // idx may dangle after this; it is not used again
		}
		else
		{
			// All children done: close u and fold its size into the parent frame.
			out[u] = tim;
			stk.pop_back();
			if (!stk.empty()) sz[stk.back().first] += sz[u];
		}
	}
}

// Number of ways (modulo mod) to fill the "free" vertices of tagged vertex u. These
// are the vertices of u's subtree that lie outside every nearest tagged descendant.
// The fill must use the budget tag[u] minus those descendants' tags. This is the
// stars-and-bars count C(budget + freeCnt - 1, budget), presumably over
// non-negative vertex values. nck returns 0 for a negative budget.
int getco(int u)
{
	const int lo = in[u] + 1;  // strict subtree of u in Euler order
	const int hi = out[u];
	const int blocked = stsz.get(1, n, 1, lo, hi);        // sizes of nearest tagged descendants
	const int budget = tag[u] - sts.get(1, n, 1, lo, hi);  // tag left after those descendants
	const int freeCnt = sz[u] - blocked;                   // always counts u itself
	return nck(freeCnt - 1 + budget, budget);
}

// Returns 1 when the tags of u's nearest tagged descendants already exceed tag[u]
// (so u's constraint cannot be met), 0 otherwise.
int invalid(int u)
{
	const int claimed = sts.get(1, n, 1, in[u] + 1, out[u]);
	if (claimed > tag[u]) return 1;
	return 0;
}

// Returns the nearest strict ancestor of u that currently carries a tag.
// At position in[x], stup holds the number of tagged vertices among x and its
// ancestors. That count never increases when moving up the tree.
// Start at v = parent(u). If T is the nearest tagged ancestor, every vertex from v
// up to T has the same count as v, and T's parent has a strictly smaller one.
// T is therefore the highest ancestor of v whose count equals v's, and binary
// lifting finds it greedily.
// The count at u itself must not be the reference. It includes T, so climbing
// while it matches overshoots to parent(T). When u is tagged, it also includes u,
// which stops the climb at parent(u).
// Assumes the root is always tagged and p[i][root] == root, so the climb saturates
// there.
int UP(int u)
{
	int v = p[0][u];
	const int cnt = stup.get(1, n, 1, in[v]);
	for (int i = LOG - 1; i >= 0; i--)
	{
		if (stup.get(1, n, 1, in[p[i][v]]) == cnt) v = p[i][v];
	}
	return v;
}


// Reads the tree (root 1, required value r at the root) and q queries. Prints the
// number of valid assignments modulo mod for the initial state and after every
// query; 0 is printed while any tagged vertex is over-constrained.
// Query "1 u v" places tag v on u; query "2 u" removes u's tag.
signed main()
{
	ios::sync_with_stdio(0);
	cin.tie(0);
	// Factorial table for nck; every vertex starts untagged (the root is set below).
	f[0] = 1;
	for (int i = 1; i < MXN; i++) 
	{
		f[i] = (f[i - 1] * i) % mod;
		tag[i] = -1;
	}
	cin >> n >> r;
	for (int i = 1; i < n; i++)
	{
		int u, v;
		cin >> u >> v;
		adj[u].push_back(v);
		adj[v].push_back(u);
	}
	int q;
	cin >> q;
	// The root is always tagged with r; it is its own parent so ancestor jumps stop there.
	tag[1] = r;
	p[0][1] = 1;
	dfs(1);
	// Binary-lifting table: p[i][j] is the 2^i-th ancestor of j (clamped at the root).
	for (int i = 1; i < LOG; i++) for (int j = 1; j <= n; j++) p[i][j] = p[i - 1][p[i - 1][j]];
	stup.init(n), stsz.init(n), sts.init(n);
	// The root's tag covers every Euler position.
	stup.add(1, n, 1, 1, n, 1);
	// NOTE(review): queries always start at in[u] + 1 >= 2, so the root's slot here is never summed.
	stsz.add(1, n, 1, in[1], 1);
	res = getco(1);
	cout << res << '\n';
	while (q--)
	{
		int t;
		cin >> t;
		if (t == 1)
		{
			// Query "1 u v": u becomes tagged and takes its part of P's region.
			int u, v;
			cin >> u >> v;
			// P = UP(u), used as u's nearest tagged strict ancestor.
			int P = UP(u);
			// Total size of u's nearest tagged descendants (they now hang below u instead of P).
			int S = stsz.get(1, n, 1, in[u] + 1, out[u]);
			tag[u] = v;
			// Take P's old factor (or invalid mark) out before its region changes.
			int ff = invalid(P);
			if (!ff) res = (res * pw(getco(P), mod - 2, mod)) % mod;
			else z--;
			// u's slot becomes sz[u] - S; P's slot shrinks by the same amount.
			stsz.add(1, n, 1, in[u], -S + sz[u]);
			stsz.add(1, n, 1, in[P], S - sz[u]);
			
			// Same bookkeeping for tag budgets: u keeps tag[u] minus its descendants' tags.
			S = sts.get(1, n, 1, in[u] + 1, out[u]);
			sts.add(1, n, 1, in[u], -S + tag[u]);
			sts.add(1, n, 1, in[P], S - tag[u]);
			// Put P's new factor and u's factor back (or count them as invalid).
			ff = invalid(P);
			if (!ff) res = (res * getco(P)) % mod;
			else z++;
			ff = invalid(u);
			if (!ff) res = (res * getco(u)) % mod;
			else z++;

			// Every vertex in u's subtree gains one tagged ancestor-or-self.
			stup.add(1, n, 1, in[u], out[u], 1);
		}
		else
		{
			// Query "2 u": u's tag is dropped and its region merges back into P.
			int u;
			cin >> u;
			int P = UP(u);
			int S = stsz.get(1, n, 1, in[u] + 1, out[u]);
			// Both P's and u's factors (or invalid marks) are removed before the merge.
			int ff = invalid(P);
			if (!ff) res = (res * pw(getco(P), mod - 2, mod)) % mod;
			else z--;
			ff = invalid(u);
			if (!ff) res = (res * pw(getco(u), mod - 2, mod)) % mod;
			else z--;
			// u's slot returns to 0; P regains sz[u] - S.
			stsz.add(1, n, 1, in[u], S - sz[u]);
			stsz.add(1, n, 1, in[P], -S + sz[u]);

			// Likewise for tag budgets.
			S = sts.get(1, n, 1, in[u] + 1, out[u]);
			sts.add(1, n, 1, in[u], S - tag[u]);
			sts.add(1, n, 1, in[P], -S + tag[u]);
			// Re-add P's merged factor (or count it as invalid).
			ff = invalid(P);
			if (!ff) res = (res * getco(P)) % mod;
			else z++;

			stup.add(1, n, 1, in[u], out[u], -1);
			tag[u] = -1;
		}
		// Any over-constrained tagged vertex makes the answer 0.
		if (z) cout << 0 << '\n';
		else cout << res << '\n';
	}
}
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...