#include <bits/stdc++.h>
using namespace std;
// `int` is redefined to 64-bit so every "int" in this file (globals, array
// elements, loop indices, D) is long long; that is why main is `signed main`.
#define int long long
#define ll long long
#define pii pair<int, int>
#define F first
#define S second
#define PB push_back
#define ALL(x) x.begin(), x.end()
#define FOR(i, n) for (int i = 0; i < n; i++)
// Fast stream I/O: untie cin from cout and drop C stdio synchronisation.
#define NYOOM ios::sync_with_stdio(0); cin.tie(0);
// Plain newline instead of std::endl, avoiding a flush on every line.
#define endl '\n'
// NOTE: despite its name, INF is used throughout as the modulus 1e9 + 7.
const int INF = 1e9 + 7;
// Unused in this file.
const ll LLINF = 1ll<<60;
// Capacity of the per-vertex arrays; vertices are indexed 1..n, so n must be < maxn.
const int maxn = 1e5 + 10;
// Number of tree vertices and the number of portal layers read from input.
int n, D;
// toggle[v]: for the tree rooted at the current root, the "toggle" count of v
//            (1 + win_sum[v] if v is losing, lose_sum[v] if v has exactly one
//            losing child, 0 otherwise).
int toggle[maxn] = {0};
// cnt[v]: number of losing children of v.
int cnt[maxn] = {0};
// win_sum[v] / lose_sum[v]: sum of toggle[] over winning / losing children (mod INF).
int win_sum[maxn] = {0};
int lose_sum[maxn] = {0};
// Coefficients of the transition matrix described below.
int a = 0, b = 0, c = 0, d = 0;
// Number of vertices that are winning / losing when used as the root.
int cnt_win = 0, cnt_lose = 0;
// lose[v]: true iff v has no losing child (v is a losing position).
bool lose[maxn] = {0};
// Adjacency lists of the tree (1-indexed).
vector<int> adj[maxn];
// Transition between consecutive layers, with u = #win and v = #lose:
//   u' = a * u + b * v
//   v' = c * u + d * v
void dfs(int v, int p){
for (auto u : adj[v]){
if (u == p) continue;
dfs(u, v);
if (lose[u]) cnt[v]++, lose_sum[v] = (lose_sum[v] + toggle[u]) % INF;
else win_sum[v] = (win_sum[v] + toggle[u]) % INF;
}
if (cnt[v] == 0) lose[v] = true, toggle[v] = 1 + win_sum[v];
else if (cnt[v] == 1) lose[v] = false, toggle[v] = lose_sum[v];
else lose[v] = false, toggle[v] = 0;
}
void add(int v, int u){
if (lose[u]) lose_sum[v] = (lose_sum[v] + toggle[u]) % INF, cnt[v]++;
else win_sum[v] = (win_sum[v] + toggle[u]) % INF;
if (cnt[v] == 0) lose[v] = true, toggle[v] = 1 + win_sum[v];
else if (cnt[v] == 1) lose[v] = false, toggle[v] = lose_sum[v];
else lose[v] = false, toggle[v] = 0;
}
void remove(int v, int u){
if (lose[u]) lose_sum[v] = (lose_sum[v] - toggle[u] + INF) % INF, cnt[v]--;
else win_sum[v] = (win_sum[v] - toggle[u] + INF) % INF;
if (cnt[v] == 0) lose[v] = true, toggle[v] = 1 + win_sum[v];
else if (cnt[v] == 1) lose[v] = false, toggle[v] = lose_sum[v];
else lose[v] = false, toggle[v] = 0;
}
// Rerooting pass. On entry, the per-vertex arrays describe the tree rooted at v
// (p is the vertex we came from). For every vertex as root it
//   * counts the root in cnt_win / cnt_lose, and
//   * adds the root's contribution to the transition coefficients a, b, c, d.
// Each root shift is undone after the recursive call, so on return the arrays
// again describe the tree rooted at v.
void reroot(int v, int p){
    const int t = toggle[v];
    if (!lose[v]){
        cnt_win++;
        a = (a + n) % INF;
        b = (b + (n - t)) % INF;
        d = (d + t) % INF;
    }
    else{
        cnt_lose++;
        // Any node can connect to a win state and keep v losing.
        c = (c + n) % INF;
        d = (d + (n - t)) % INF;
        // The t toggling choices turn v into a win.
        b = (b + t) % INF;
    }
    for (auto child : adj[v]){
        if (child == p) continue;
        // Move the root from v to child: detach child from v, hang v under child.
        remove(v, child);
        add(child, v);
        reroot(child, v);
        // Restore v as the root before trying the next neighbour.
        remove(child, v);
        add(v, child);
    }
}
// Multiplies two 2x2 matrices stored row-major as {m00, m01, m10, m11} and
// returns s2 * s1 (note the order) modulo INF. Entries are assumed to be
// already reduced mod INF, so each partial sum stays below 2 * INF^2 and fits
// in 64 bits.
vector<int> combine(vector<int> s1, vector<int> s2){
    vector<int> re(4);
    for (int row = 0; row < 2; row++){
        for (int col = 0; col < 2; col++){
            re[2 * row + col] = (s2[2 * row] * s1[col] + s2[2 * row + 1] * s1[2 + col]) % INF;
        }
    }
    return re;
}
// Raises the row-major 2x2 matrix s to the power e (mod INF) by binary
// exponentiation; e == 0 yields the identity. combine(x, s) computes s * x,
// which is harmless here because all factors are powers of s and commute.
vector<int> mpow(vector<int> s, int e){
    vector<int> result = {1, 0, 0, 1};
    for (; e != 0; e /= 2){
        if (e % 2 != 0) result = combine(result, s);
        s = combine(s, s);
    }
    return result;
}
signed main(){
NYOOM;
cin >> n >> D;
FOR(i, n - 1){
int u, v; cin >> u >> v;
adj[u].PB(v); adj[v].PB(u);
}
dfs(1, 1);
bool start_lose = lose[1], start_toggle = toggle[1];
reroot(1, 1);
vector<int> s = mpow({a, b, c, d}, D - 1);
FOR(i, n){
if (lose[i + 1]) cnt_lose++;
else cnt_win++;
}
int temp = cnt_win;
cnt_win = (s[0] * cnt_win + s[1] * cnt_lose) % INF;
cnt_lose = (s[2] * temp + s[3] * cnt_lose) % INF;
if (start_lose){
cout << (start_toggle * cnt_lose) % INF;
}
else{
int total = (n * cnt_win) % INF;
total = (total + (n - start_toggle) * cnt_lose) % INF;
cout << total << endl;
}
}
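// Grader results:
// | # | Result    | Execution time | Memory  | Grader output        |
// |---|-----------|----------------|---------|----------------------|
// | 1 | Correct   | 2 ms           | 2644 KB | Output is correct    |
// | 2 | Incorrect | 3 ms           | 2772 KB | Output isn't correct |
// | 3 | Halted    | 0 ms           | 0 KB    | -                    |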
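// Grader results:
// | # | Result    | Execution time | Memory  | Grader output        |
// |---|-----------|----------------|---------|----------------------|
// | 1 | Incorrect | 2 ms           | 2644 KB | Output isn't correct |
// | 2 | Halted    | 0 ms           | 0 KB    | -                    |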
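// Grader results:
// | # | Result    | Execution time | Memory  | Grader output        |
// |---|-----------|----------------|---------|----------------------|
// | 1 | Correct   | 1 ms           | 2680 KB | Output is correct    |
// | 2 | Incorrect | 2 ms           | 2644 KB | Output isn't correct |
// | 3 | Halted    | 0 ms           | 0 KB    | -                    |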
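// Grader results:
// | # | Result    | Execution time | Memory  | Grader output        |
// |---|-----------|----------------|---------|----------------------|
// | 1 | Correct   | 1 ms           | 2680 KB | Output is correct    |
// | 2 | Incorrect | 2 ms           | 2644 KB | Output isn't correct |
// | 3 | Halted    | 0 ms           | 0 KB    | -                    |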
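// Grader results:
// | # | Result    | Execution time | Memory  | Grader output        |
// |---|-----------|----------------|---------|----------------------|
// | 1 | Correct   | 1 ms           | 2680 KB | Output is correct    |
// | 2 | Incorrect | 2 ms           | 2644 KB | Output isn't correct |
// | 3 | Halted    | 0 ms           | 0 KB    | -                    |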
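// Grader results:
// | # | Result    | Execution time | Memory  | Grader output        |
// |---|-----------|----------------|---------|----------------------|
// | 1 | Correct   | 1 ms           | 2680 KB | Output is correct    |
// | 2 | Incorrect | 2 ms           | 2644 KB | Output isn't correct |
// | 3 | Halted    | 0 ms           | 0 KB    | -                    |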
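// Grader results:
// | # | Result    | Execution time | Memory  | Grader output        |
// |---|-----------|----------------|---------|----------------------|
// | 1 | Correct   | 1 ms           | 2680 KB | Output is correct    |
// | 2 | Incorrect | 2 ms           | 2644 KB | Output isn't correct |
// | 3 | Halted    | 0 ms           | 0 KB    | -                    |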
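// Grader results:
// | # | Result    | Execution time | Memory  | Grader output        |
// |---|-----------|----------------|---------|----------------------|
// | 1 | Correct   | 2 ms           | 2644 KB | Output is correct    |
// | 2 | Incorrect | 3 ms           | 2772 KB | Output isn't correct |
// | 3 | Halted    | 0 ms           | 0 KB    | -                    |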