// This submission was migrated from the previous version of oj.uz, which graded on a different machine; resubmitting it may produce a different result.
#include <bits/stdc++.h>
using namespace std;
#define int long long
// Compile-time limits shared by the global arrays and trees.
constexpr int MXN = 1000005;     // array capacity (presumably max node count plus slack); also the factorial table size
constexpr int LOG = 20;          // binary-lifting levels; 2^20 = 1048576 > MXN
constexpr int mod = 1000000007;  // prime modulus used for all counting
// Segment tree over positions 1..sz supporting "add val to every position in
// [lx, rx]" and "read one position". An update is stored on the topmost nodes
// that exactly tile [lx, rx] and is never pushed down, so a point query sums
// the deltas found along the root-to-leaf path of the queried position.
// Callers pass the full node span (l, r) = (1, sz) with root index x = 1.
struct FindUp
{
    int sz;          // number of positions; recorded by init, not read by the methods
    vector<int> st;  // st[x]: delta that applies to every position inside node x

    // Resets the tree to all zeros for positions 1..n.
    void init(int n)
    {
        st.assign((n + 1) << 2, 0);
        sz = n;
    }

    // Adds val to every position of [lx, rx]; (l, r) is the span of node x.
    void add(int l, int r, int x, int lx, int rx, int val)
    {
        const bool disjoint = rx < l || r < lx;
        if (disjoint) return;
        const bool covered = lx <= l && r <= rx;
        if (covered)
        {
            st[x] += val;
            return;
        }
        const int m = (l + r) / 2;
        add(l, m, x * 2, lx, rx, val);
        add(m + 1, r, x * 2 + 1, lx, rx, val);
    }

    // Returns the accumulated value at position ind, walking down from node x.
    int get(int l, int r, int x, int ind)
    {
        int total = 0;
        for (;;)
        {
            total += st[x];
            if (l == r) return total;
            const int m = (l + r) / 2;
            if (ind <= m)
            {
                r = m;
                x = x * 2;
            }
            else
            {
                l = m + 1;
                x = x * 2 + 1;
            }
        }
    }
};
// Segment tree over positions 1..sz supporting "add val at one position" and
// "sum of positions in [lx, rx]" (an empty range with lx > rx sums to 0).
// Callers pass the full node span (l, r) = (1, sz) with root index x = 1.
struct FindVAL
{
    int sz;          // number of positions; recorded by init, not read by the methods
    vector<int> st;  // st[x]: sum of every position covered by node x

    // Resets the tree to all zeros for positions 1..n.
    void init(int n)
    {
        st.assign((n + 1) << 2, 0);
        sz = n;
    }

    // Adds val at position ind, then refreshes the sums on the path back up to x.
    void add(int l, int r, int x, int ind, int val)
    {
        const int top = x;
        while (l != r)
        {
            const int m = (l + r) / 2;
            if (ind <= m)
            {
                r = m;
                x = x * 2;
            }
            else
            {
                l = m + 1;
                x = x * 2 + 1;
            }
        }
        st[x] += val;
        while (x != top)
        {
            x /= 2;  // parent of node x
            st[x] = st[x * 2] + st[x * 2 + 1];
        }
    }

    // Returns the sum over [lx, rx] restricted to node x's span (l, r).
    int get(int l, int r, int x, int lx, int rx)
    {
        if (rx < l || r < lx) return 0;
        if (lx <= l && r <= rx) return st[x];
        const int m = (l + r) / 2;
        const int left = get(l, m, x * 2, lx, rx);
        const int right = get(m + 1, r, x * 2 + 1, lx, rx);
        return left + right;
    }
};
// ---- Global state ----
int n;                 // number of nodes (n - 1 edges are read)
int r;                 // value attached to the root: main sets tag[1] = r
vector<int> adj[MXN];  // undirected adjacency lists
int f[MXN];            // f[i] = i! mod `mod`
int tag[MXN];          // tag value of a node, or -1 while it is untagged
int p[LOG][MXN];       // p[k][v] = 2^k-th ancestor of v; the root is its own parent
int sz[MXN];           // subtree sizes
int in[MXN];           // Euler-tour entry time
int out[MXN];          // largest entry time inside the node's subtree
int tim;               // Euler-tour clock
FindUp stup;           // per position: number of tagged nodes whose subtree contains it
FindVAL stsz;          // per tagged node: size contribution; strict-subtree sums are subtracted from sz in getco
FindVAL sts;           // per tagged node: tag contribution; strict-subtree sums are subtracted from tag in getco
int res = 1;           // product of getco over the currently valid tagged nodes, mod `mod`
int z = 0;             // number of tagged nodes that are currently invalid
// Modular exponentiation by repeated squaring: returns a^b mod c, in [0, c).
// Requires c >= 1. A non-positive exponent yields 1 % c. Intermediate products
// reach (c - 1)^2, so c must be small enough for that to fit in `int`.
// (`int` is long long under this file's macro; callers pass c = mod.)
int pw(int a, int b, int c)
{
    a %= c;
    if (a < 0) a += c;   // built-in % keeps the sign of the dividend; normalise into [0, c)
    int result = 1 % c;  // makes c == 1 return 0 instead of 1
    while (b > 0)        // b > 0 (not b != 0): b >>= 1 never reaches 0 for negative b
    {
        if (b & 1) result = (result * a) % c;  // reduce by the parameter c, not a global modulus
        a = (a * a) % c;
        b >>= 1;
    }
    return result;
}
// Binomial coefficient C(n, k) modulo the prime `mod`, built from the factorial
// table f and a Fermat inverse (denominator^(mod-2)) computed by pw.
// Returns 0 when n < 0, k < 0, or k > n. The caller must keep n below MXN.
int nck(int n, int k)
{
    const bool outOfRange = n < 0 || k < 0 || k > n;
    if (outOfRange) return 0;
    const int denom = f[k] * f[n - k] % mod;
    const int inv = pw(denom, mod - 2, mod);
    return f[n] * inv % mod;
}
// Depth-first traversal of the tree rooted at a. It fills:
//   sz[v]   - subtree size,
//   in[v]   - entry time (pre-order, from the global clock tim),
//   out[v]  - largest entry time inside v's subtree,
//   p[0][v] - parent of every non-root v.
// p[0][a] must be preset by the caller (main sets p[0][1] = 1), because a
// neighbour equal to p[0][u] is skipped as the parent.
// It uses an explicit stack instead of recursion, so a path-shaped tree with
// ~MXN nodes cannot overflow the call stack. Children are still visited in
// adjacency-list order, so in/out match the recursive formulation exactly.
void dfs(int a)
{
    vector<pair<int, int>> stk;  // (node, index of the next neighbour to examine)
    sz[a] = 1;
    in[a] = ++tim;
    stk.emplace_back(a, 0);
    while (!stk.empty())
    {
        const int u = stk.back().first;
        const int i = stk.back().second++;
        if (i >= (int)adj[u].size())
        {
            // All neighbours done: close u and fold its size into its parent.
            out[u] = tim;
            stk.pop_back();
            if (!stk.empty()) sz[stk.back().first] += sz[u];
            continue;
        }
        const int v = adj[u][i];
        if (v == p[0][u]) continue;
        p[0][v] = u;
        sz[v] = 1;
        in[v] = ++tim;
        stk.emplace_back(v, 0);
    }
}
// Count (mod `mod`) contributed by tagged node u. The strict subtree of u is
// Euler positions in[u]+1 .. out[u]. Over that range:
//   - `covered` (from stsz) is subtracted from sz[u],
//   - `claimed` (from sts) is subtracted from tag[u].
// Write m = sz[u] - covered and s = tag[u] - claimed. The result is the
// stars-and-bars value C(m - 1 + s, s). Presumably m is the size of u's own
// region and s is the total its nodes must share.
// Returns 0 if s < 0 (nck rejects a negative k); main tracks that case
// separately via invalid(). Each range sum is queried once and reused.
int getco(int u)
{
    const int covered = stsz.get(1, n, 1, in[u] + 1, out[u]);
    const int claimed = sts.get(1, n, 1, in[u] + 1, out[u]);
    const int slotsMinusOne = sz[u] - covered - 1;
    const int remaining = tag[u] - claimed;
    return nck(slotsMinusOne + remaining, remaining);
}
// Returns 1 when the tags summed over u's strict subtree (via sts) already
// exceed tag[u], i.e. getco(u) would have a negative remainder; otherwise 0.
int invalid(int u)
{
    const int claimed = sts.get(1, n, 1, in[u] + 1, out[u]);
    if (claimed > tag[u]) return 1;
    return 0;
}
// Returns the nearest ancestor-or-self of u that is currently tagged in stup.
// stup at in[w] counts the tagged nodes whose subtree contains w. That count
// never increases when moving to an ancestor, so binary lifting can climb to
// the highest ancestor whose count still equals u's, which is the nearest
// tagged ancestor-or-self.
// Relies on the root being tagged (main adds 1 over the whole range) and on
// p[k][root] == root.
int UP(int u)
{
    const int mark = stup.get(1, n, 1, in[u]);
    int cur = u;
    for (int lvl = LOG; lvl-- > 0;)
    {
        const int jump = p[lvl][cur];
        if (stup.get(1, n, 1, in[jump]) == mark) cur = jump;
    }
    return cur;
}
// Reads a tree of n nodes rooted at 1 whose root carries tag value r, then q
// queries, and prints the answer before any query and after each one:
//   1 u v : tag node u with value v;
//   2 u   : remove u's tag.
// The answer is the product of getco over all tagged nodes, mod `mod`. It is
// printed as 0 whenever some tagged node is invalid (z > 0).
// NOTE(review): the code assumes u is untagged for type 1, tagged for type 2,
// and never the root; none of this is checked here.
// NOTE(review): nck indexes f[] with sz + tag - 1; this assumes the problem
// limits keep that below MXN. Confirm.
signed main()
{
ios::sync_with_stdio(0);
cin.tie(0);
// Factorial table for nck; every node starts untagged (-1).
f[0] = 1;
for (int i = 1; i < MXN; i++)
{
f[i] = (f[i - 1] * i) % mod;
tag[i] = -1;
}
cin >> n >> r;
for (int i = 1; i < n; i++)
{
int u, v;
cin >> u >> v;
adj[u].push_back(v);
adj[v].push_back(u);
}
int q;
cin >> q;
// The root is always tagged with r and is its own parent, which stops binary lifting.
tag[1] = r;
p[0][1] = 1;
dfs(1);
for (int i = 1; i < LOG; i++) for (int j = 1; j <= n; j++) p[i][j] = p[i - 1][p[i - 1][j]];
stup.init(n), stsz.init(n), sts.init(n);
// Mark the root as tagged over every Euler position. The root's own stsz
// entry is never inside a strict-subtree query, so its value here is not read.
stup.add(1, n, 1, 1, n, 1);
stsz.add(1, n, 1, in[1], 1);
res = getco(1);
cout << res << '\n';
while (q--)
{
int t;
cin >> t;
if (t == 1)
{
int u, v;
cin >> u >> v;
// P = nearest tagged ancestor; u is not yet marked in stup.
int P = UP(u);
int S = stsz.get(1, n, 1, in[u] + 1, out[u]);
tag[u] = v;
// Remove P's old factor (or its invalid mark) before changing its strict-subtree sums.
int ff = invalid(P);
if (!ff) res = (res * pw(getco(P), mod - 2, mod)) % mod;
else z--;
// u now accounts for sz[u] nodes inside P's strict subtree. The deltas at
// in[u] and in[P] cancel for any tagged ancestor above P.
stsz.add(1, n, 1, in[u], -S + sz[u]);
stsz.add(1, n, 1, in[P], S - sz[u]);
// Same bookkeeping for tag values.
S = sts.get(1, n, 1, in[u] + 1, out[u]);
sts.add(1, n, 1, in[u], -S + tag[u]);
sts.add(1, n, 1, in[P], S - tag[u]);
// Re-add P's new factor and u's factor (or count them as invalid).
ff = invalid(P);
if (!ff) res = (res * getco(P)) % mod;
else z++;
ff = invalid(u);
if (!ff) res = (res * getco(u)) % mod;
else z++;
stup.add(1, n, 1, in[u], out[u], 1);
}
else
{
int u;
cin >> u;
// Unmark u first so UP finds the nearest tagged proper ancestor.
stup.add(1, n, 1, in[u], out[u], -1);
int P = UP(u);
int S = stsz.get(1, n, 1, in[u] + 1, out[u]);
// Remove the factors of both P and u (or their invalid marks).
int ff = invalid(P);
if (!ff) res = (res * pw(getco(P), mod - 2, mod)) % mod;
else z--;
ff = invalid(u);
if (!ff) res = (res * pw(getco(u), mod - 2, mod)) % mod;
else z--;
// Reset u's stsz/sts entries to 0 and hand its descendants back to P.
stsz.add(1, n, 1, in[u], S - sz[u]);
stsz.add(1, n, 1, in[P], -S + sz[u]);
S = sts.get(1, n, 1, in[u] + 1, out[u]);
sts.add(1, n, 1, in[u], S - tag[u]);
sts.add(1, n, 1, in[P], -S + tag[u]);
ff = invalid(P);
if (!ff) res = (res * getco(P)) % mod;
else z++;
tag[u] = -1;
}
if (z) cout << 0 << '\n';
else cout << res << '\n';
}
}
/*
Grader result tables captured from the oj.uz page (the results had not loaded yet):
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
*/