#include <bits/stdc++.h>
#define endl '\n'
using namespace std;
using ll = long long;              // 64-bit working integer type
constexpr int MAXN = 1e5 + 10;     // bound for the vertex-indexed global arrays
constexpr ll MAXD = 1e18 + 10;     // not referenced anywhere in this file
constexpr ll MOD = 1e9 + 7;        // modulus for every counted quantity
struct matrix
{
int n;
vector < vector < ll > > m;
matrix()
{
n = 0;
m.clear();
}
matrix(int n)
{
this->n = n;
m.resize(n);
for(int i = 0; i < n; i++)
{
m[i].resize(n, 0LL);
m[i][i] = 1LL;
}
}
int size()
{
return n;
}
friend matrix operator*(matrix a, matrix b)
{
int sz = a.size();
matrix result(sz);
for(int k = 0; k < sz; k++)
{
for(int i = 0; i < sz; i++)
{
for(int j = 0; j < sz; j++)
{
result.m[i][j] = (result.m[i][j] + (a.m[i][k] * b.m[k][j]) % MOD) % MOD;
}
}
}
return result;
}
};
// Modular exponentiation by square-and-multiply: returns a^b mod MOD for b >= 0.
// The base is reduced into [0, MOD) first so that a * a can never overflow,
// whatever the magnitude or sign of the incoming a.
ll binpow(ll a, ll b)
{
    a %= MOD;
    if(a < 0)
        a += MOD;
    ll result = 1;
    while(b)
    {
        if(b & 1)
            result = (result * a) % MOD;
        a = (a * a) % MOD;
        b /= 2;
    }
    return result;
}
// Matrix power a^b by repeated squaring. The accumulator starts as the
// identity of a's dimension, so b == 0 yields the identity.
matrix binpow(matrix a, ll b)
{
    matrix acc(a.size());
    for(; b; b /= 2)
    {
        const bool take = b & 1;
        if(take)
            acc = acc * a;
        a = a * a;  // a now holds the next power-of-two power
    }
    return acc;
}
// Row vector times matrix modulo MOD: result[j] = sum_i a[i] * b.m[i][j].
// The dimension is taken from b, so any square b works; a must have at least
// b.n entries.
vector < ll > mult(const vector < ll > &a, const matrix &b)
{
    const int sz = b.n;
    vector < ll > result(sz, 0LL);
    for(int i = 0; i < sz; i++)
    {
        for(int j = 0; j < sz; j++)
        {
            result[j] = (result[j] + b.m[i][j] * a[i]) % MOD;
        }
    }
    return result;
}
// n: vertex count (read() reads n - 1 edges; vertices are used as 1..n).
// d: read alongside n; solve() uses it as the number of layers to combine.
ll n, d;
vector < ll > adj[MAXN];    // undirected adjacency lists filled by read()

// Per-vertex state for the current rooting (maintained by dfs/find_crit/change):
ll dpt[MAXN];               // number of children v with dpt[v] == 0
ll crit[MAXN];              // critical count of the vertex's subtree
ll win[MAXN];               // sum of crit[v] over children with dpt[v] != 0
ll lose[MAXN];              // sum of crit[v] over children with dpt[v] == 0

// Per-root results recorded by reroot():
ll dp[MAXN];                // dpt[r] when r is the root (nonzero => counted as winning)
ll dph[MAXN];               // crit[r] when r is the root

// Aggregates over all roots, built in find_win_lose():
ll cntwin, cntlose;         // sum of dph over roots with dp != 0 / dp == 0 (mod MOD)
ll pot;                     // number of roots with dp == 0
void read()
{
cin >> n >> d;
for(int i = 1; i <= n - 1; i++)
{
int u, v;
cin >> u >> v;
adj[u].push_back(v);
adj[v].push_back(u);
}
}
// Post-order pass over the tree rooted at u (par is u's parent, 0 for none):
// dpt[x] becomes the number of children c of x with dpt[c] == 0.
// A nonzero dpt marks a vertex as winning.
void dfs(int u, int par)
{
    for(int child : adj[u])
    {
        if(child != par)
        {
            dfs(child, u);
            dpt[u] += (dpt[child] == 0);
        }
    }
}
// Post-order pass computing crit/win/lose for the tree rooted at u; dpt must
// already be filled by dfs. For each vertex u:
//   win[u]  = sum of crit over children with dpt != 0,
//   lose[u] = sum of crit over children with dpt == 0,
//   crit[u] = 1 + win[u]          if dpt[u] == 0,
//             crit of the single losing child  if dpt[u] == 1,
//             0 (left untouched)  otherwise.
void find_crit(int u, int par)
{
    const bool u_wins = dpt[u] != 0;
    if(!u_wins)
        crit[u] = 1;
    for(int child : adj[u])
    {
        if(child == par)
            continue;
        find_crit(child, u);
        const bool child_wins = dpt[child] != 0;
        if(child_wins)
            win[u] += crit[child];
        else
            lose[u] += crit[child];
        // The two cases are mutually exclusive (dpt[u] is 1 vs 0).
        if((dpt[u] == 1 && !child_wins) || (!u_wins && child_wins))
            crit[u] += crit[child];
    }
}
// Moves the root from u to its neighbour v. It first detaches v from u's
// child aggregates and recomputes crit[u]. It then attaches u as a child of v
// and recomputes crit[v]. Calling change(v, u) afterwards undoes it.
void change(int u, int v)
{
    // crit from the aggregates, same rule that find_crit applies.
    const auto refresh = [](int x)
    {
        if(dpt[x] > 1)
            crit[x] = 0;
        else if(dpt[x] == 1)
            crit[x] = lose[x];
        else
            crit[x] = win[x] + 1;
    };

    // v stops being a child of u.
    if(dpt[v])
        win[u] -= crit[v];
    else
    {
        --dpt[u];
        lose[u] -= crit[v];
    }
    refresh(u);

    // u becomes a child of v.
    if(dpt[u])
        win[v] += crit[u];
    else
    {
        ++dpt[v];
        lose[v] += crit[u];
    }
    refresh(v);
}
// Visits every vertex as the root, with u currently the root. It records
// dp[u]/dph[u], then shifts the root into each child subtree via change()
// and shifts it back afterwards.
void reroot(int u, int par)
{
    dp[u] = dpt[u];
    dph[u] = crit[u];
    for(int nxt : adj[u])
    {
        if(nxt != par)
        {
            change(u, nxt);  // nxt becomes the root
            reroot(nxt, u);
            change(nxt, u);  // u is the root again
        }
    }
}
// Runs the rooted passes from vertex 1 and rerooting, then aggregates over
// all roots r:
//   pot     = number of roots with dp[r] == 0,
//   cntlose = sum of dph[r] over those roots (mod MOD),
//   cntwin  = sum of dph[r] over the remaining roots (mod MOD).
void find_win_lose()
{
    dfs(1, 0);
    find_crit(1, 0);
    reroot(1, 0);
    for(int r = 1; r <= n; r++)
    {
        if(dp[r])
        {
            cntwin = (cntwin + dph[r]) % MOD;
            continue;
        }
        ++pot;
        cntlose = (cntlose + dph[r]) % MOD;
    }
}
void solve()
{
matrix mat(2);
mat.m[0][0] = cntwin - cntlose;
mat.m[0][1] = 0LL;
mat.m[1][0] = 1LL;
mat.m[1][1] = (n * n) % MOD;
vector < ll > tmp = {pot, pot * ((n * n) % MOD) % MOD};
tmp = mult(tmp, binpow(mat, d - 1));
ll ans = 0;
if(dp[1])
ans = (dph[1] * tmp[0]) % MOD;
else
ans = (binpow(n, 2LL * d) - (dph[1], tmp[0]) % MOD + MOD) % MOD;
ans = (binpow(n, 2LL * d) - ans + MOD) % MOD;
cout << ans << endl;
}
// Entry point: fast iostreams (this program uses only cin/cout), then
// input -> per-root analysis -> layered combination and output.
int main()
{
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);

    read();
    find_win_lose();
    solve();
}
Compilation message
startrek.cpp: In function 'void solve()':
startrek.cpp:266:43: warning: left operand of comma operator has no effect [-Wunused-value]
266 | ans = (binpow(n, 2LL * d) - (dph[1], tmp[0]) % MOD + MOD) % MOD;
| ~~~~~^
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Incorrect |
1 ms |
5212 KB |
Output isn't correct |
2 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Correct |
1 ms |
5208 KB |
Output is correct |
2 |
Correct |
1 ms |
5212 KB |
Output is correct |
3 |
Correct |
1 ms |
5212 KB |
Output is correct |
4 |
Correct |
1 ms |
5212 KB |
Output is correct |
5 |
Correct |
1 ms |
5212 KB |
Output is correct |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Incorrect |
1 ms |
5212 KB |
Output isn't correct |
2 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Incorrect |
1 ms |
5212 KB |
Output isn't correct |
2 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Incorrect |
1 ms |
5212 KB |
Output isn't correct |
2 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Incorrect |
1 ms |
5212 KB |
Output isn't correct |
2 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Incorrect |
1 ms |
5212 KB |
Output isn't correct |
2 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Incorrect |
1 ms |
5212 KB |
Output isn't correct |
2 |
Halted |
0 ms |
0 KB |
- |