This submission was migrated from the previous version of oj.uz, which used a different machine for grading. It may receive a different result if resubmitted.
#include <bits/stdc++.h>
// Capacity of every per-vertex table; vertices are indexed 1..n, so n must be < N.
constexpr int N = 100005;
using namespace std;
// Prime modulus used for all counting.
constexpr int mod = 1000000007;

// Returns a^b modulo `mod` via binary exponentiation.
// The base is first normalised into [0, mod): this makes negative bases yield a
// proper non-negative residue and keeps `a * a` below ~1e18, so it cannot
// overflow even when the caller passes an unreduced (large) base.
// Precondition: b >= 0 (a negative exponent would never terminate the shift loop).
long long binpow(long long a, long long b) {
    assert(b >= 0);
    a %= mod;
    if (a < 0)
        a += mod;
    long long ret = 1;
    while (b > 0) {
        if (b & 1)
            ret = ret * a % mod;
        a = a * a % mod;
        b >>= 1;
    }
    return ret;
}
// Aggregate carried by the tree passes. merge()/antimerge() add/subtract one
// node into another with the pairs (a, b) and (sz, sz2) cross-swapped.
// Every field starts at zero.
struct node {
    long long a = 0;
    long long b = 0;
    long long sz = 0;
    long long sz2 = 0;
};
// 3x3 matrix with entries reduced modulo `mod`; raised to a power by binexpo().
// Default construction yields the zero matrix.
struct matrix{
    long long a[3][3];
    matrix(){
        for (auto &row : a)
            for (auto &cell : row)
                cell = 0;
    }
    // Returns (*this) * other with every entry reduced modulo `mod`.
    // Assumes entries of both operands already lie in [0, mod), so each
    // product stays below ~1e18 and fits in long long.
    // `other` is taken by const reference (no 72-byte copy per multiply), and the
    // operator is const so it also works on const operands.
    matrix operator*(const matrix &other) const {
        matrix ret;
        for (int i = 0; i < 3; i++)
            for (int j = 0; j < 3; j++)
                for (int k = 0; k < 3; k++)
                    ret.a[i][k] = (ret.a[i][k] + a[i][j] * other.a[j][k]) % mod;
        return ret;
    }
};
// Matrix power a^b by square-and-multiply; b == 0 yields the identity matrix.
matrix binexpo(matrix a,long long b){
    matrix result;
    for (int diag = 0; diag < 3; ++diag)
        result.a[diag][diag] = 1;
    for (; b; b >>= 1) {
        if (b & 1)
            result = result * a;
        a = a * a;
    }
    return result;
}
// ---- Global tree state (vertices are 1-indexed) ----
vector<int> adj[N];   // undirected adjacency lists, filled in solve()
long long ans = 0;    // NOTE(review): never used anywhere in this file
long long n;          // number of vertices
long long d;          // read in solve(); drives the matrix power d - 1
int win[N];           // win[v] == 1: set by dfs() (subtree) and dfs2() (rerooted)
int sub[N];           // sub[v]: dfs() result for the subtree rooted at v
int top[N];           // top[v]: starts at 1 in solve(); dfs2() clears it to 0
int len[N];           // len[v]: number of vertices in v's subtree (tout - tin + 1)
int sum = 0;          // NOTE(review): unused; dfs4() declares a local with this name
int timer = 1;        // next Euler-tour timestamp handed out by dfs()
int tin[N];           // entry timestamp
int tout[N];          // last timestamp inside v's subtree
// Rooted DFS from v (parent par). Assigns Euler-tour times tin/tout and the
// subtree size len[v]. v is marked winning (sub[v] = 1 and win[v] = 1) iff at
// least one child's recursive call returned false; otherwise sub[v] = 0 and
// win[v] is left unchanged. Returns whether sub[v] is non-zero.
// NOTE(review): recursion depth equals tree height — path-like trees need a large stack.
bool dfs(int v,int par){
    tin[v] = timer;
    ++timer;
    bool hasLosingChild = false;
    for (int child : adj[v]) {
        if (child == par)
            continue;
        const bool childWins = dfs(child, v);
        if (!childWins)
            hasLosingChild = true;
    }
    if (hasLosingChild)
        win[v] = 1;
    sub[v] = hasLosingChild ? 1 : 0;
    tout[v] = timer - 1;
    len[v] = tout[v] - tin[v] + 1;
    return sub[v];
}
// Top-down pass that extends dfs()'s subtree results to every vertex.
// For each child c of v, if none of v's OTHER neighbours is losing (the parent
// side counts as losing when top[v] == 0, a child x when sub[x] == 0), then c
// is marked winning (win[c] = 1) and its parent side as losing (top[c] = 0).
void dfs2(int v,int par){
    int losing = top[v] ? 0 : 1;
    for (int c : adj[v])
        if (c != par && !sub[c])
            ++losing;
    for (int c : adj[v]) {
        if (c == par)
            continue;
        const int others = losing - (sub[c] ? 0 : 1);
        if (others == 0) {
            win[c] = 1;
            top[c] = 0;
        }
        dfs2(c, v);
    }
}
// Per-vertex aggregates produced by the tree passes:
node subval[N];  // dfs3(): aggregate over v's own subtree
node topval[N];  // dfs4(): for child u, aggregate over the part of the tree outside u's subtree
node val[N];     // dfs4(): combines topval[v] with the subval of v's children
// Adds b into a with both pairs swapped: a.a/a.b receive b.b/b.a, and
// a.sz/a.sz2 receive b.sz2/b.sz (the latter two reduced modulo `mod`).
// b is taken by value, so merging a node with itself is well defined.
void merge(node &a,node b){
    a.a += b.b;
    a.b += b.a;
    a.sz = (a.sz + b.sz2) % mod;
    a.sz2 = (a.sz2 + b.sz) % mod;
}
// Undoes merge(a, b): subtracts b from a with the same pair swap. The sz/sz2
// fields are brought back into range by adding `mod` once before reducing.
void antimerge(node &a,node b){
    a.a = a.a - b.b;
    a.b = a.b - b.a;
    a.sz = (a.sz - b.sz2 + mod) % mod;
    a.sz2 = (a.sz2 - b.sz + mod) % mod;
}
void dfs3(int v,int par){
vector<int> places;
subval[v] = node();
for(auto u:adj[v]){
if(u == par)continue;
if(!sub[u]){
places.push_back(u);
}
dfs3(u,v);
}
if(places.empty()){
subval[v].b++;
for(auto u:adj[v]){
if(u == par)continue;
merge(subval[v],subval[u]);
}
}
if(places.size() == 1){
subval[v].sz += n * (len[v] - len[places[0]]);
subval[v].sz %= mod;
merge(subval[v],subval[places[0]]);
}
if(places.size() > 1){
subval[v].sz += n * len[v];
subval[v].sz %= mod;
}
}
// Top-down rerooting pass over the aggregates built by dfs3().
// For vertex v (parent par) it:
//   * fills val[v] with the same three-way case analysis dfs3() uses, but over
//     ALL neighbours of v: the parent side contributes topval[v] and counts as
//     losing when top[v] == 0; each child u contributes subval[u] and counts as
//     losing when sub[u] == 0;
//   * for every child u, fills topval[u] with that analysis for v's side of the
//     edge (v, u) (u's subtree removed), then recurses into u.
// Expects dfs(), dfs2() and dfs3() to have run. For the root, topval[root] is
// never written here, so it keeps its zero-initialised global value.
// NOTE(review): recursion depth equals tree height — path-like trees need a large stack.
void dfs4(int v,int par){
// Losing neighbours of v: the parent side (if top[v] == 0) and every losing child.
vector<int> places;
if(!top[v])
places.push_back(par);
val[v] = node();
for(auto u:adj[v]){
if(u == par)continue;
if(!sub[u]){
places.push_back(u);
}
}
// No losing neighbour: count v itself in b and fold in every neighbour's aggregate.
if(places.empty()){
val[v].b++;
merge(val[v],topval[v]);
for(auto u:adj[v]){
if(u == par)continue;
merge(val[v],subval[u]);
}
}
// Exactly one losing neighbour: sz gets n times the number of vertices NOT on
// that neighbour's side, then only that neighbour's aggregate is folded in.
if(places.size() == 1){
if(places[0] == par){
// The parent side holds n - len[v] vertices, so n - (n - len[v]) = len[v] remain.
val[v].sz += n * len[v];
val[v].sz %= mod;
merge(val[v],topval[v]);
}
else{
val[v].sz += n * (n - len[places[0]]);
val[v].sz %= mod;
merge(val[v],subval[places[0]]);
}
}
// Two or more losing neighbours: sz gets n times all n vertices; nothing is merged.
if(places.size() > 1){
val[v].sz += n * n;
val[v].sz %= mod;
}
// Running total of every neighbour's aggregate; each child is removed
// temporarily below to get the "everything except u" aggregate.
node sum;
merge(sum,topval[v]);
for(auto u:adj[v]){
if(u == par)continue;
merge(sum,subval[u]);
}
for(auto u:adj[v]){
if(u == par)continue;
antimerge(sum,subval[u]);
node tmp = sum;
// Number of losing neighbours of v other than u.
int x = places.size() - !sub[u];
if(x == 0){
tmp.b++;
}
if(x == 1){
tmp = node();
// The single losing neighbour other than u (places has at most two entries here).
int pos = places[0];
if(pos == u)
pos = places[1];
if(pos == par){
// v's side of edge (v,u) has n - len[u] vertices; the parent side has n - len[v].
tmp.sz += n * ( len[v] - len[u]);
tmp.sz %= mod;
merge(tmp,topval[v]);
}
else{
tmp.sz += n * (n - len[pos] - len[u]);
tmp.sz %= mod;
merge(tmp,subval[pos]);
}
}
if(x > 1){
tmp = node();
// All n - len[u] vertices on v's side of edge (v,u).
tmp.sz += n * (n - len[u]);
tmp.sz %= mod;
}
topval[u] = tmp;
// Restore u's contribution before moving to the next child.
merge(sum,subval[u]);
dfs4(u,v);
}
}
// Reads the tree and d, runs the four tree passes rooted at vertex 1, and
// prints the final count modulo `mod`.
// NOTE(review): judging by the oj.uz header this is a contest solution
// (presumably CEOI 2020 "Star Trek"); the combinatorial meaning of the
// recurrence below is not spelled out here — confirm against the statement.
void solve(){
cin >> n >> d;
// n - 1 undirected edges between vertices in 1..n.
for(int i = 1;i<n;i++){
int u,v;
cin >> u >> v;
adj[u].push_back(v);
adj[v].push_back(u);
}
vector<node> v;
long long winstate = 0;
for(int i = 1;i<=n;i++){
win[i] = 0;
sub[i] = 0;
top[i] = 1;
}
dfs(1,0);
dfs2(1,0);
dfs3(1,0);
dfs4(1,0);
// v[x - 1] == val[x]; in particular v[0] is the aggregate for vertex 1.
for(int x = 1;x<=n;x++){
v.push_back(val[x]);
}
// Totals of the per-vertex aggregates, reduced modulo mod.
long long suma = 0,sumb = 0,sumsz = 0;
for(auto u:v){
suma += u.a;
sumb += u.b;
sumsz += u.sz;
suma %= mod;
sumb %= mod;
sumsz %= mod;
}
// winstate = number of vertices with win[i] == 1.
for(int i = 1;i<=n;i++){
winstate += win[i];
}
// `single` acts on the row vector (w, n^(2i+1), n^(2i)) and maps it to
//   (w*(suma - sumb) + n^(2i+1)*sumb + n^(2i)*sumsz, n^(2i+3), n^(2i+2)),
// i.e. exactly one iteration of the commented-out loop further down.
matrix single;
single.a[0][0] = (suma - sumb +mod)%mod;
single.a[1][0] = (sumb)%mod;
single.a[2][0] = (sumsz)%mod;
single.a[1][1] = n*n%mod;
single.a[2][2] = n*n%mod;
// Apply d - 1 iterations at once; only column 0 of the product is needed,
// combined with the starting row vector (winstate, n^1, n^0).
matrix total = binexpo(single,d-1);
// NOTE(review): this local `val` shadows the global node array val[N].
long long val = 0;
val = (val + total.a[0][0] * winstate)%mod;
val = (val + total.a[1][0] * n)%mod;
val = (val + total.a[2][0] * 1)%mod;
winstate = val;
// Naive O(d) reference version of the matrix power above:
/*
for(int i = 0;i<d-1;i++){
winstate = (winstate * (suma - sumb + mod) + binpow(n,2*i+1)*sumb + binpow(n,2*i)*sumsz)%mod;
}*/
// Final step (i = d - 1) of the same recurrence, using only vertex 1's
// aggregate v[0] instead of the totals over all vertices.
winstate = (winstate * (v[0].a - v[0].b + mod) + binpow(n,2*d-1)*v[0].b + binpow(n,2*d-2)*v[0].sz)%mod;
cout << winstate;
}
// Entry point: fast unsynchronised I/O, then a single call to solve().
// Building with -DLocal redirects I/O to in.txt/out.txt and prints the elapsed CPU time.
int main(){
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
#ifdef Local
    freopen("in.txt","r",stdin);
    freopen("out.txt","w",stdout);
#endif
    int t = 1;   // number of test cases
    // cin >> t; // enable for multi-test input
    for (; t > 0; --t)
        solve();
#ifdef Local
    cout << endl << fixed << setprecision(2) << 1000.0 * clock() / CLOCKS_PER_SEC << " milliseconds.";
#endif
}
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |