#include<bits/stdc++.h>
using namespace std;
#define int long long
#define inf (int)2e18
#define nl '\n'
#define vi vector<int>
#define vvi vector<vi>
// N: capacity of the per-vertex arrays (solve() indexes vertices 1..n).
// K: dimension of the square matrices handled by mult()/exp().
// mod: modulus applied to every reported quantity.
const int N = 1e5+1, K = 2, mod = 1e9+7;
// Adjacency lists; solve() inserts each input edge in both directions.
vector<int> g[N];
// Values for the tree rooted at vertex 1, filled by dfs1():
//   lc[v] - number of children of v whose ls flag is set
//   ls[v] - 1 when lc[v] == 0, otherwise 0
//   sl[v] - sum of cr[child] over children whose ls flag is set
//   sw[v] - sum of cr[child] over children whose ls flag is clear
//   cr[v] - ls[v], plus sw[v] when lc[v] == 0, or plus sl[v] when lc[v] == 1
// NOTE(review): the names suggest ls = "losing state" and cr = "critical
// vertex count" of a two-player game on the tree - confirm against the task.
int ls[N], lc[N], cr[N], sl[N], sw[N];
// The same five quantities evaluated with each vertex treated as the root
// (all neighbours counted as children); seeded for vertex 1 in solve() and
// propagated to the other vertices by dfs2().
int lsr[N], lcr[N], crr[N], slr[N], swr[N];
// Bottom-up pass over the tree hanging below v (p is v's parent, skipped in
// the adjacency list). After all children are processed it fills, for v:
//   lc[v] - how many children have ls set,
//   sl[v] / sw[v] - sum of cr over children with ls set / clear,
//   ls[v] - 1 exactly when no child has ls set,
//   cr[v] - 1 + sw[v] when lc[v] == 0, sl[v] when lc[v] == 1, else 0.
// NOTE(review): presumably ls marks a losing position for the player to move
// and cr counts vertices whose change would flip v's state - confirm.
void dfs1(int v, int p){
    for(int child : g[v]){
        if(child == p) continue;
        dfs1(child, v);
        // ls[child] is 0 or 1 here, so adding it equals a conditional increment.
        if(ls[child]){
            lc[v]++;
            sl[v] += cr[child];
        }
        else{
            sw[v] += cr[child];
        }
    }
    switch(lc[v]){
        case 0:
            // No flagged child: v itself is flagged and counts towards cr.
            ls[v] = 1;
            cr[v] = 1 + sw[v];
            break;
        case 1:
            // Exactly one flagged child: only its contribution carries over.
            ls[v] = 0;
            cr[v] = sl[v];
            break;
        default:
            ls[v] = 0;
            cr[v] = 0;
            break;
    }
}
void dfs2(int v, int p){
for(int ch : g[v]){
if(ch == p) continue;
int lup = lcr[v] - ls[ch];
lcr[ch] = lc[ch] + (lup == 0);
lsr[ch] = (lcr[ch] == 0);
int rup = (lup == 0);
if(lup == 0) rup += swr[v] - (!ls[ch] ? cr[ch] : 0);
if(lup == 1) rup += slr[v] - (ls[ch] ? cr[ch] : 0);
swr[ch] = sw[ch];
slr[ch] = sl[ch];
if(lup == 0) slr[ch] += rup;
else swr[ch] += rup;
crr[ch] = lsr[ch];
if(lcr[ch] == 0) crr[ch] += swr[ch];
if(lcr[ch] == 1) crr[ch] += slr[ch];
dfs2(ch, v);
}
}
// Returns the K x K matrix product m1 * m2, reducing the running sum of every
// entry modulo mod after each term. Neither argument is modified.
vvi mult(vvi& m1, vvi& m2){
    vvi product(K, vi(K, 0));
    for(int row = 0; row < K; ++row){
        for(int col = 0; col < K; ++col){
            int acc = 0;
            for(int mid = 0; mid < K; ++mid){
                acc = (acc + m1[row][mid] * m2[mid][col]) % mod;
            }
            product[row][col] = acc;
        }
    }
    return product;
}
// Returns m raised to the power p using repeated squaring built on mult()
// (so entries are reduced modulo mod). p <= 0 yields the K x K identity.
//
// The squaring runs on a private copy: the caller's matrix is no longer
// overwritten with its successive squares as a side effect of the call.
vvi exp(vvi& m, int p){
    vvi base = m;
    vvi res(K, vi(K, 0));
    for(int i = 0; i < K; i++) res[i][i] = 1;
    while(p > 0){
        if(p % 2) res = mult(res, base);
        p /= 2;
        // The square is only needed if another bit of p remains.
        if(p > 0) base = mult(base, base);
    }
    return res;
}
// Returns x^p modulo mod by binary exponentiation; p <= 0 returns 1.
//
// The base is first normalised into [0, mod). With the file's 64-bit `int`
// and mod near 1e9 every product below then stays under mod^2, so an
// arbitrarily large base can no longer overflow, and a negative base yields
// a non-negative residue.
int pwr(int x, int p){
    x %= mod;
    if(x < 0) x += mod;
    int res = 1;
    while(p > 0){
        if(p % 2) res = res * x % mod;
        x = x * x % mod;
        p /= 2;
    }
    return res;
}
void solve(){
int n, d;
cin>>n>>d;
for(int i = 1; i < n; i++){
int x, y;
cin>>x>>y;
g[x].push_back(y);
g[y].push_back(x);
}
dfs1(1, 0);
lsr[1] = ls[1];
lcr[1] = lc[1];
crr[1] = cr[1];
slr[1] = sl[1];
swr[1] = sw[1];
dfs2(1, 0);
int l = 0, c = 0;
for(int i = 1; i <= n; i++){
if(lsr[i]){
l++;
c = (c - crr[i] + mod) % mod;
}
else{
c = (c + crr[i]) % mod;
}
}
int n2 = n * n % mod;
vvi m = {{n2, 0}, {l, c}};
m = exp(m, d-1);
int cnt = (m[1][0] * n2 + m[1][1] * l) % mod;
cout<<(lsr[1] ? crr[1] * cnt : pwr(n, 2*d) - crr[1] * cnt % mod + mod) % mod;
}
// Entry point: detaches the C++ streams from C stdio and from each other,
// then runs solve() once per test case. The case count is fixed at 1; the
// statement that would read it from input is left disabled below.
signed main(){
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);
    //cin>>t;
    for(int remaining = 1; remaining > 0; --remaining){
        solve();
    }
    return 0;
}
| # | Verdict | Execution time | Memory | Grader output |
|---|
| Fetching results... |
| # | Verdict | Execution time | Memory | Grader output |
|---|
| Fetching results... |
| # | Verdict | Execution time | Memory | Grader output |
|---|
| Fetching results... |
| # | Verdict | Execution time | Memory | Grader output |
|---|
| Fetching results... |
| # | Verdict | Execution time | Memory | Grader output |
|---|
| Fetching results... |
| # | Verdict | Execution time | Memory | Grader output |
|---|
| Fetching results... |
| # | Verdict | Execution time | Memory | Grader output |
|---|
| Fetching results... |
| # | Verdict | Execution time | Memory | Grader output |
|---|
| Fetching results... |