#include<iostream>
#include<algorithm>
#include<iomanip>
#include<cmath>
#include<cstring>
#include<vector>
#include<queue>
#include<stack>
#include<tuple>
#include<set>
#include<map>
#include<random>
#include<chrono>
using namespace std;
// Prime modulus for every count this program reports.
const int MOD=1000000007;
// Array bound for per-node data; node ids run from 1 to n.
const int MAX_N=1e5+10;
int n;                     // number of nodes, read in main()
long long d;               // second input value; used as exponent / series length in solve()
std::vector<int>g[MAX_N];  // undirected adjacency lists, filled in main()

// (a + b) mod MOD. Expects a, b in [0, MOD); 2*(MOD-1) still fits in int.
int add(int a,int b)
{
    return (a+b)%MOD;
}

// (a - b) mod MOD, always in [0, MOD). Expects a, b in [0, MOD) so that
// a - b + MOD is non-negative before the reduction.
int sub(int a,int b)
{
    return (a-b+MOD)%MOD;
}

// (a * b) mod MOD; the product is formed in 64 bits to avoid int overflow.
int mul(int a,int b)
{
    return (int)(1LL*a*b%MOD);
}

// a^b mod MOD by recursive squaring (recursion depth ~ log2(b)); st(a, 0) == 1.
int st(int a,long long b)
{
    if(b==0)return 1;
    const int half=st(a,b/2);
    int res=mul(half,half);
    if(b%2==1)res=mul(res,a);
    return res;
}

// a / b mod MOD via Fermat's little theorem (MOD is prime): a * b^(MOD-2).
// b must be non-zero modulo MOD; for b == 0 this silently yields 0.
int del(int a,int b)
{
    return mul(a,st(b,MOD-2));
}

// Geometric series A * (1 + Q + Q^2 + ... + Q^(N-1)) mod MOD, for N >= 0.
// The closed form A*(Q^N - 1)/(Q - 1) needs Q - 1 to be invertible, so the
// ratio Q == 1 (mod MOD) is handled separately: every term is A and the sum is
// A*N. Without that branch del() would divide by zero and return 0.
int S(int A,int Q,long long N)
{
    if(Q%MOD==1)return mul(A,(int)(N%MOD));
    const int up=mul(A,sub(st(Q,N),1));
    const int down=sub(Q,1);
    return del(up,down);
}
// ---- Per-node state shared by dfsdown / dfsup / precompute / solve ----
// Indexed by node id (1..n). Everything starts zeroed; precompute() raises the
// three boolean flags to 1 before the traversals, which only ever clear them.

// Win/lose flags. By the recurrence in dfsdown, 0 means "no move leads to a
// 0-position", i.e. a losing position for the player to move.
bool statdown[MAX_N]{};   // a node's own subtree (tree rooted at node 1)
bool statup[MAX_N]{};     // the parent's side of the tree with the node's subtree removed
bool stat[MAX_N]{};       // the whole tree re-rooted at the node

// Losing-children bookkeeping.
int cnt0[MAX_N]{};        // how many children c have statdown[c] == 0
vector<int>zeroes[MAX_N]; // those children, in adjacency-list order
int cntzero{};            // how many nodes end with stat == 0

// Rerooting DP counters (see dfsdown / dfsup for the exact recurrences).
// NOTE(review): they appear to count nodes whose status change would flip the
// corresponding position; interpretation inferred from the code, confirm.
int dpdown[MAX_N]{};
int dpup[MAX_N]{};
int sum[MAX_N]{};         // sum of dpdown over a node's children
int crit[MAX_N]{};        // per-node combination of dpdown / dpup built in precompute()
// Bottom-up (post-order) pass of the tree DP rooted at node 1.
// For node u (parent par) it fills:
//   cnt0[u], zeroes[u] - count / list of children c with statdown[c] == 0;
//   sum[u]             - sum of dpdown over all children;
//   dpdown[u]          - 1 + sum[u] if u has no losing child, dpdown[c] if it has
//                        exactly one losing child c, otherwise 0;
//   statdown[u]        - cleared to 0 iff u has no losing child.
// Expects statdown[] preset to 1 and the counters zeroed (done by precompute()).
// Recursion depth equals the depth of the tree (up to n).
void dfsdown(int u,int par)
{
    // Pass 1: finish every child first, remembering which ones are losing.
    for(size_t idx=0;idx<g[u].size();++idx)
    {
        const int child=g[u][idx];
        if(child==par)continue;
        dfsdown(child,u);
        if(!statdown[child])
        {
            ++cnt0[u];
            zeroes[u].push_back(child);
        }
    }

    // cnt0[u] is final from here on.
    const bool noLosingChild=(cnt0[u]==0);
    const bool singleLosingChild=(cnt0[u]==1);

    // Pass 2: aggregate the children's dpdown values.
    for(const int child:g[u])
    {
        if(child==par)continue;
        if(noLosingChild || (singleLosingChild && !statdown[child]))
            dpdown[u]+=dpdown[child];
        sum[u]+=dpdown[child];
    }

    if(noLosingChild)
    {
        statdown[u]=0;
        ++dpdown[u];
    }
}
// Top-down (pre-order) pass of the rerooting DP; requires dfsdown(1,0) to have run.
// On entry statup[u] and dpup[u] describe u's parent side: they were written by the
// call for u's parent (for the root they keep their initial values 1 and 0).
// Decides stat[u] (whole tree rooted at u), then for every child v fills statup[v]
// and dpup[v] before recursing into v.
// NOTE(review): dpup appears to count nodes whose status change would flip the
// parent-side position (by analogy with dpdown); inferred from the code, confirm.
// Recursion depth equals the depth of the tree rooted at 1 (up to n).
void dfsup(int u,int par)
{
// u is losing overall iff the parent side is not losing and no child is losing.
if(statup[u]==1 && cnt0[u]==0){stat[u]=0;cntzero++;}
for(int v:g[u])
{
if(v==par)continue;
// Parent side seen from v = tree without v's subtree, entered at u. It is losing
// (statup[v]=0) iff u's own parent side is not losing and no child other than v is.
if(statup[u]==1 && (cnt0[u]==0 or (cnt0[u]==1 && statdown[v]==0)))statup[v]=0;
// onlyup accumulates the value stored into dpup[v].
int onlyup=0;
// No losing child of u besides v: u's own parent side contributes dpup[u].
if(cnt0[u]==0 or (cnt0[u]==1 && statdown[v]==0))onlyup+=dpup[u];
if(statup[u]==1)
{
// Number of losing children of u, not counting v.
int cntzerodown=cnt0[u]-(statdown[v]==0);
if(cntzerodown==1)
{
// Exactly one losing child other than v. zeroes[u] then holds either that single
// child (v not losing) or two children one of which is v, so index 1 is in range.
int spec=zeroes[u][0];
if(spec==v)spec=zeroes[u][1];
onlyup+=dpdown[spec];
}
else if(cntzerodown==0)
{
// No losing neighbor on this side: dpdown of u's other children, plus 1
// (mirrors the +1 dfsdown adds for a node with no losing child).
onlyup+=(sum[u]-dpdown[v]+1);
}
}
dpup[v]=onlyup;
dfsup(v,u);
}
}
// Runs both tree passes (rooted at node 1) and derives crit[] for every node.
void precompute()
{
    // Every flag starts at 1; the traversals below only ever clear them.
    for(int node=1;node<=n;++node)
        statup[node]=statdown[node]=stat[node]=true;

    dfsdown(1,0);   // bottom-up: statdown, cnt0, zeroes, dpdown, sum
    dfsup(1,0);     // top-down:  stat, cntzero, statup, dpup

    // crit = parent-side contribution when the node has no losing child,
    //      + subtree contribution when the parent side is not losing.
    for(int node=1;node<=n;++node)
    {
        const int fromUp=(cnt0[node]==0)?dpup[node]:0;
        const int fromDown=statup[node]?dpdown[node]:0;
        crit[node]+=fromUp+fromDown;
    }
}
// Combines the per-node data into the final count and prints it modulo MOD.
// E is the signed sum of crit[] (added where stat == 1, subtracted where stat == 0);
// series is the geometric sum from S(); total = n^(2d).
void solve()
{
    int E=0;
    for(int node=1;node<=n;++node)
        E=stat[node]?add(E,crit[node]):sub(E,crit[node]);

    // Named `series` so it does not shadow the global sum[] array.
    const int series=S(mul(cntzero,st(n,2*(d-1))),del(E,st(n,2)),d);
    const int hit=mul(crit[1],series);
    const int total=st(n,2*d);

    // If node 1 is losing (stat[1] == 0) the answer is hit; otherwise total - hit.
    const int ans=stat[1]?sub(total,hit):hit;
    cout<<ans<<"\n";
}
// Entry point: reads n and d, then the n-1 undirected tree edges, and runs the solver.
signed main ()
{
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);

    cin>>n>>d;
    for(int remaining=n-1;remaining>0;--remaining)
    {
        int a,b;
        cin>>a>>b;
        g[a].push_back(b);
        g[b].push_back(a);
    }

    precompute();
    solve();
    return 0;
}
/*
Judge status table captured from the submission page before the results had
loaded. It is kept here as a comment so the translation unit still compiles.

# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
*/