#include <bits/stdc++.h>
using namespace std;
// Make every `int` below 64-bit so that products of two residues modulo `mod`
// fit; `main` is therefore declared `signed` (an `int main` would expand to
// `long long main`).
#define int long long
// Capacity of the per-vertex arrays and of the factorial table f[].
const int MXN = 1e6 + 5;
// Number of binary-lifting levels (2^20 > MXN).
const int LOG = 20;
// Prime modulus used for all counts.
const int mod = 1e9 + 7;
struct FindUp
{
    // Segment tree over positions 1..sz with two operations:
    //   add(..., lx, rx, val): add `val` to every position in [lx, rx];
    //   get(..., ind):         current value at position `ind`.
    // Node x covers [l, r]; its children are 2x and 2x+1. A node's st[] entry is
    // an addend for its whole range and is never pushed down, so a point value
    // is the sum of the entries on the root-to-leaf path.
    int sz;
    vector<int> st;

    // Reset to n zero-valued positions.
    void init(int n)
    {
        sz = n;
        st.assign((n + 1) << 2, 0);
    }

    // Range add; call as add(1, sz, 1, lx, rx, val).
    void add(int l, int r, int x, int lx, int rx, int val)
    {
        if (rx < l || r < lx) return;   // no overlap with [lx, rx]
        if (lx <= l && r <= rx)         // node fully covered: record addend here
        {
            st[x] += val;
            return;
        }
        const int half = (l + r) / 2;
        add(l, half, x * 2, lx, rx, val);
        add(half + 1, r, x * 2 + 1, lx, rx, val);
    }

    // Point query; call as get(1, sz, 1, ind).
    int get(int l, int r, int x, int ind)
    {
        int total = 0;
        while (l != r)
        {
            total += st[x];
            const int half = (l + r) / 2;
            if (ind <= half)
            {
                r = half;
                x = x * 2;
            }
            else
            {
                l = half + 1;
                x = x * 2 + 1;
            }
        }
        return total + st[x];
    }
};
struct FindVAL
{
    // Segment tree over positions 1..sz holding point values with range-sum
    // queries. Node x covers [l, r]; its children are 2x and 2x+1; st[x] is the
    // sum of its range.
    int sz;
    vector<int> st;

    // Reset to n zero-valued positions.
    void init(int n)
    {
        sz = n;
        st.assign((n + 1) << 2, 0);
    }

    // Point add (position `ind` += val); call as add(1, sz, 1, ind, val).
    void add(int l, int r, int x, int ind, int val)
    {
        if (l != r)
        {
            const int half = (l + r) / 2;
            if (ind > half) add(half + 1, r, x * 2 + 1, ind, val);
            else add(l, half, x * 2, ind, val);
            st[x] = st[x * 2] + st[x * 2 + 1];   // refresh this node's sum
            return;
        }
        st[x] += val;
    }

    // Sum over [lx, rx]; an empty range (lx > rx) yields 0.
    // Call as get(1, sz, 1, lx, rx).
    int get(int l, int r, int x, int lx, int rx)
    {
        const bool disjoint = l > rx || r < lx;
        if (disjoint) return 0;
        const bool inside = lx <= l && r <= rx;
        if (inside) return st[x];
        const int half = (l + r) / 2;
        const int fromLeft = get(l, half, x * 2, lx, rx);
        return fromLeft + get(half + 1, r, x * 2 + 1, lx, rx);
    }
};
// n: number of vertices; r: tag value given to the root (vertex 1) in main.
int n, r;
// Adjacency lists of the input tree (each edge stored in both directions).
vector<int> adj[MXN];
// f[i] = i! mod `mod`; tag[v] = constraint value at v, or -1 if v is untagged.
int f[MXN], tag[MXN];
// p[i][v] = 2^i-th ancestor of v (the root is its own parent); sz[v] = subtree size.
int p[LOG][MXN], sz[MXN];
// Euler-tour times from dfs: v's subtree occupies positions [in[v], out[v]].
int in[MXN], out[MXN], tim;
// stup: at each position, the number of tagged vertices on that vertex's root
// path (inclusive).
FindUp stup;
// stsz / sts: for each tagged non-root v, the entry at in[v] is sz[v] (resp.
// tag[v]) minus the total size (resp. tag) of v's nearest tagged descendants.
// The entries telescope, so summing over v's strict subtree
// (in[v] + 1 .. out[v]) gives the total size / tag of those nearest tagged
// descendants. The root's own entry is never read by a query.
FindVAL stsz, sts;
// res: product of getco(v) over tagged vertices whose constraint is satisfiable;
// z: number of tagged vertices whose constraint is currently unsatisfiable.
int res = 1, z = 0;
int pw(int a, int b, int c)
{
    // Returns a^b mod c by binary exponentiation (requires b >= 0, c >= 1).
    // All reductions use the modulus argument `c`; a negative `a` is first
    // normalised into [0, c), and `1 % c` makes the c == 1 case return 0.
    // Intermediate products need (c - 1)^2 to fit in `int` (which the file's
    // `#define int long long` makes 64-bit).
    a %= c;
    if (a < 0) a += c;
    int res = 1 % c;
    while (b)
    {
        if (b & 1) res = (res * a) % c;
        a = (a * a) % c;
        b >>= 1;
    }
    return res;
}
int nck(int n, int k)
{
    // Binomial coefficient C(n, k) modulo `mod`, built from the factorial table
    // f[] and a Fermat inverse (exponent mod - 2) computed by pw().
    // Arguments outside 0 <= k <= n trip the assertion; if assertions are
    // compiled out (NDEBUG) the function returns 0 for them instead.
    // Requires n < MXN so that f[n] is in range.
    if (k < 0 || n < 0 || n < k)
    {
        assert(0);
        return 0;
    }
    const int denominator = f[k] * f[n - k] % mod;
    return f[n] * pw(denominator, mod - 2, mod) % mod;
}
void dfs(int a)
{
    // Depth-first walk from `a`. It fills sz[] (subtree sizes), p[0][] (parents),
    // and the Euler-tour interval [in[v], out[v]] of every vertex.
    // Vertices are visited in adjacency-list order, exactly as a recursive
    // walk would. The walk is iterative because recursion needs one call
    // frame per tree level, which can exhaust the stack on path-like trees.
    // Expects p[0][a] to be set beforehand (main makes the root its own parent).
    // Each stack entry is (vertex, index of the next neighbour to examine).
    vector<pair<int, int>> stk;
    sz[a] = 1;
    in[a] = ++tim;
    stk.push_back({a, 0});
    while (!stk.empty())
    {
        const int v = stk.back().first;
        const int idx = stk.back().second;
        if (idx < static_cast<int>(adj[v].size()))
        {
            stk.back().second = idx + 1;
            const int c = adj[v][idx];
            if (c == p[0][v]) continue;   // skip the edge back to the parent
            p[0][c] = v;
            sz[c] = 1;
            in[c] = ++tim;
            stk.push_back({c, 0});
        }
        else
        {
            // All children finished: close v's interval and add its size to its parent.
            out[v] = tim;
            stk.pop_back();
            if (!stk.empty()) sz[stk.back().first] += sz[v];
        }
    }
}
int getco(int u)
{
    // Counting factor contributed by tagged vertex u.
    // - freeCnt: vertices in u's subtree minus the subtrees of u's nearest
    //   tagged descendants.
    // - rest: tag[u] minus the tags of those descendants.
    // The result is C(freeCnt - 1 + rest, rest): the stars-and-bars number of
    // ways to split `rest` into `freeCnt` non-negative parts. The parts are
    // presumably the values of the free vertices; confirm against the
    // problem statement.
    // Caller must ensure invalid(u) is false (rest >= 0); nck() asserts otherwise.
    // Each segment-tree range is queried once (it used to be queried twice).
    const int coveredSize = stsz.get(1, n, 1, in[u] + 1, out[u]);
    const int usedTag = sts.get(1, n, 1, in[u] + 1, out[u]);
    const int freeCnt = sz[u] - coveredSize;
    const int rest = tag[u] - usedTag;
    return nck(freeCnt - 1 + rest, rest);
}
int invalid(int u)
{
    // Returns 1 when u's nearest tagged descendants already demand more than
    // tag[u] in total (u's constraint cannot be met), otherwise 0.
    const int demanded = sts.get(1, n, 1, in[u] + 1, out[u]);
    if (demanded > tag[u]) return 1;
    return 0;
}
int UP(int u)
{
    // Returns the nearest tagged *proper* ancestor of u. This works both when
    // u is untagged (type-1 query) and when u is tagged (type-2 query).
    // stup's value at a vertex counts the tagged vertices on its root path,
    // inclusive. Let x be that count at parent(u). The highest ancestor of
    // parent(u) that still has count x is exactly the nearest tagged vertex at
    // or above parent(u), because that vertex's own parent has count x - 1.
    // Relies on the root being tagged and being its own parent (set up in main),
    // so jumps saturate at the root.
    // The climb must start at parent(u) and return the vertex reached. The
    // previous code started at u and returned the reached vertex's parent,
    // which is one level too high for untagged u. For tagged u it returned
    // parent(u) even when that vertex is untagged.
    int v = p[0][u];
    const int x = stup.get(1, n, 1, in[v]);
    for (int i = LOG - 1; i >= 0; i--)
    {
        const int up = p[i][v];
        if (stup.get(1, n, 1, in[up]) == x) v = up;
    }
    return v;
}
// Reads the tree and q queries. Prints the number of valid assignments modulo
// `mod` once initially and then after every query:
//   type 1 "u v": tag vertex u with value v;
//   type 2 "u":   remove the tag from vertex u.
// The answer is 0 while any tagged vertex is unsatisfiable (z > 0), else res.
signed main()
{
// Fast stream I/O; no C stdio is used.
ios::sync_with_stdio(0);
cin.tie(0);
// Factorials for nck(); every vertex starts untagged (-1).
f[0] = 1;
for (int i = 1; i < MXN; i++)
{
f[i] = (f[i - 1] * i) % mod;
tag[i] = -1;
}
cin >> n >> r;
for (int i = 1; i < n; i++)
{
int u, v;
cin >> u >> v;
adj[u].push_back(v);
adj[v].push_back(u);
}
int q;
cin >> q;
// The root carries tag r. Making it its own parent lets binary-lifting jumps
// saturate at the root.
tag[1] = r;
p[0][1] = 1;
dfs(1);
for (int i = 1; i < LOG; i++) for (int j = 1; j <= n; j++) p[i][j] = p[i - 1][p[i - 1][j]];
stup.init(n), stsz.init(n), sts.init(n);
// Every vertex has the tagged root on its root path.
stup.add(1, n, 1, 1, n, 1);
// NOTE(review): the root's stsz entry is never read, because every query range
// starts at in[u] + 1 > in[1] = 1. The value 1 here therefore does not matter.
stsz.add(1, n, 1, in[1], 1);
res = getco(1);
cout << res << '\n';
while (q--)
{
int t;
cin >> t;
if (t == 1)
{
// Type 1: tag u with value v.
int u, v;
cin >> u >> v;
// P is intended to be the nearest tagged proper ancestor of u; it loses u's
// region to u.
int P = UP(u);
// S = total size of u's nearest tagged descendants.
int S = stsz.get(1, n, 1, in[u] + 1, out[u]);
tag[u] = v;
// Divide P's old factor out of res (Fermat inverse), or un-count it from z.
int ff = invalid(P);
if (!ff) res = (res * pw(getco(P), mod - 2, mod)) % mod;
else z--;
// Give u its stsz entry sz[u] - S and remove the same amount from P's entry.
stsz.add(1, n, 1, in[u], -S + sz[u]);
stsz.add(1, n, 1, in[P], S - sz[u]);
// Same bookkeeping for tags: S is now the total tag of u's nearest tagged descendants.
S = sts.get(1, n, 1, in[u] + 1, out[u]);
sts.add(1, n, 1, in[u], -S + tag[u]);
sts.add(1, n, 1, in[P], S - tag[u]);
// Re-add P's new factor and add u's factor (or count them in z).
ff = invalid(P);
if (!ff) res = (res * getco(P)) % mod;
else z++;
ff = invalid(u);
if (!ff) res = (res * getco(u)) % mod;
else z++;
// Every vertex in u's subtree gains one tagged ancestor.
stup.add(1, n, 1, in[u], out[u], 1);
}
else
{
// Type 2: remove the tag from u; its region merges back into P's.
int u;
cin >> u;
int P = UP(u);
int S = stsz.get(1, n, 1, in[u] + 1, out[u]);
// Remove both P's and u's current factors (or their z counts).
int ff = invalid(P);
if (!ff) res = (res * pw(getco(P), mod - 2, mod)) % mod;
else z--;
ff = invalid(u);
if (!ff) res = (res * pw(getco(u), mod - 2, mod)) % mod;
else z--;
// Undo the type-1 bookkeeping: clear u's stsz/sts entries and return the
// amounts to P.
stsz.add(1, n, 1, in[u], S - sz[u]);
stsz.add(1, n, 1, in[P], -S + sz[u]);
S = sts.get(1, n, 1, in[u] + 1, out[u]);
sts.add(1, n, 1, in[u], S - tag[u]);
sts.add(1, n, 1, in[P], -S + tag[u]);
// Re-add P's enlarged-region factor.
ff = invalid(P);
if (!ff) res = (res * getco(P)) % mod;
else z++;
stup.add(1, n, 1, in[u], out[u], -1);
tag[u] = -1;
}
// Any unsatisfiable constraint makes the count 0.
if (z) cout << 0 << '\n';
else cout << res << '\n';
}
}
/*
 * Pasted judge output (table 1).
 * #  | Result  | Execution time | Memory    | Grader output
 * 1  | Correct | 85 ms          | 155428 KB | Output is correct
 * 2  | Correct | 97 ms          | 155352 KB | Output is correct
 * 3  | Correct | 89 ms          | 155476 KB | Output is correct
 * 4  | Correct | 98 ms          | 155472 KB | Output is correct
 * 5  | Correct | 78 ms          | 151892 KB | Output is correct
 * 6  | Correct | 16 ms          | 88664 KB  | Output is correct
 * 7  | Correct | 16 ms          | 88412 KB  | Output is correct
 * 8  | Correct | 16 ms          | 88412 KB  | Output is correct
 * 9  | Correct | 85 ms          | 150776 KB | Output is correct
 * 10 | Correct | 85 ms          | 150612 KB | Output is correct
 * 11 | Correct | 92 ms          | 150868 KB | Output is correct
 * 12 | Correct | 83 ms          | 147284 KB | Output is correct
 * 13 | Correct | 79 ms          | 150356 KB | Output is correct
 */
/*
 * Pasted judge output (table 2).
 * # | Result    | Execution time | Memory   | Grader output
 * 1 | Incorrect | 15 ms          | 85852 KB | Output isn't correct
 * 2 | Halted    | 0 ms           | 0 KB     | -
 */
/*
 * Pasted judge output (table 3).
 * # | Result        | Execution time | Memory    | Grader output
 * 1 | Runtime error | 205 ms         | 312400 KB | Execution killed with signal 6
 * 2 | Halted        | 0 ms           | 0 KB      | -
 */
/*
 * Pasted judge output (table 4).
 * # | Result        | Execution time | Memory    | Grader output
 * 1 | Runtime error | 220 ms         | 304212 KB | Execution killed with signal 6
 * 2 | Halted        | 0 ms           | 0 KB      | -
 */
/*
 * Pasted judge output (table 5).
 * #  | Result    | Execution time | Memory    | Grader output
 * 1  | Correct   | 85 ms          | 155428 KB | Output is correct
 * 2  | Correct   | 97 ms          | 155352 KB | Output is correct
 * 3  | Correct   | 89 ms          | 155476 KB | Output is correct
 * 4  | Correct   | 98 ms          | 155472 KB | Output is correct
 * 5  | Correct   | 78 ms          | 151892 KB | Output is correct
 * 6  | Correct   | 16 ms          | 88664 KB  | Output is correct
 * 7  | Correct   | 16 ms          | 88412 KB  | Output is correct
 * 8  | Correct   | 16 ms          | 88412 KB  | Output is correct
 * 9  | Correct   | 85 ms          | 150776 KB | Output is correct
 * 10 | Correct   | 85 ms          | 150612 KB | Output is correct
 * 11 | Correct   | 92 ms          | 150868 KB | Output is correct
 * 12 | Correct   | 83 ms          | 147284 KB | Output is correct
 * 13 | Correct   | 79 ms          | 150356 KB | Output is correct
 * 14 | Incorrect | 15 ms          | 85852 KB  | Output isn't correct
 * 15 | Halted    | 0 ms           | 0 KB      | -
 */