#include <algorithm>
#include <iostream>
#include <numeric>
#include <cassert>
#include <vector>
#include <random>
#include <queue>
#include <stack>
#include <set>
// 64-bit signed integer used for modular products and counts.
using llong = long long;

// Capacity of every per-vertex array below.
constexpr int MAXN = 100010;
// Modulus applied to all counts.
constexpr int MOD = 1000000007;
// Large sentinel; not referenced by the code visible in this file.
constexpr int INF = 2000000000;

// Both values are read by input(): n drives the edge-reading loop and the
// 1..n vertex loops, d bounds the layer loop in solve().
int n;
int d;

// Running count maintained by solve(); it starts as the number of vertices
// whose win flag is false and is consumed by calc().
int sumL;

// sz[v]: subtree size of v for the rooting used by the latest dfs() call.
int sz[MAXN];
// dp[v]: scratch count (already reduced modulo MOD) written by calc().
llong dp[MAXN];
// win[v]: set by solve() to (cntChildLoses[v] > 0).
bool win[MAXN];
// cntChildLoses[v]: counter filled by calcSubtreeWin() and adjusted by calcUpWin().
int cntChildLoses[MAXN];
// g[v]: adjacency list of v, built by input().
std::vector<int> g[MAXN];
// Returns num raised to the power `base`, reduced modulo MOD, using
// square-and-multiply. Note that the parameter called `base` is the exponent.
// The exponent must be non-negative.
llong power(int num, int base)
{
    llong result = 1;
    llong factor = num;  // num^(2^k) for the bit currently being examined
    for (int e = base; e != 0; e >>= 1)
    {
        if (e & 1)
        {
            result = (result * factor) % MOD;
        }
        factor = (factor * factor) % MOD;
    }
    return result;
}
// Walks the part of the tree below `node` (entered from `par`; 0 means no
// parent) and adds to cntChildLoses[node] one unit for every child whose own
// call returned false. Returns true when the counter of `node` is non-zero,
// i.e. a vertex is reported as winning when at least one child is not.
bool calcSubtreeWin(int node, int par)
{
    int &losingChildren = cntChildLoses[node];
    for (const int child : g[node])
    {
        if (child != par && !calcSubtreeWin(child, node))
        {
            ++losingChildren;
        }
    }
    return losingChildren != 0;
}
// Re-rooting pass that runs after calcSubtreeWin(): turns cntChildLoses[v]
// from "losing children below v" into "losing neighbours of v when v is the
// root", walking the tree top-down from `node` (entered from `par`; 0 means
// no parent).
//
// When `node` is processed, cntChildLoses[par] already covers all neighbours
// of `par`, while cntChildLoses[node] still covers only the children of
// `node`. Seen from `node`, `par` is a losing neighbour exactly when `par`
// has no losing neighbour other than `node` itself.
void calcUpWin(int node, int par)
{
    if (par != 0)
    {
        // 1 when `node` was counted among the losing children of `par`.
        // Must be evaluated before cntChildLoses[node] is modified below.
        const int nodeCountedAsLosing = (cntChildLoses[node] == 0) ? 1 : 0;
        if (cntChildLoses[par] - nodeCountedAsLosing == 0)
        {
            ++cntChildLoses[node];
        }
    }
    for (const int u : g[node])
    {
        if (u == par)
        {
            continue;
        }
        calcUpWin(u, node);
    }
}
// Computes the size of the subtree of `node` when the tree is entered from
// `par` (0 means no parent), stores it in sz[node] and returns it. Sizes of
// all vertices below `node` are refreshed as a side effect.
int dfs(int node, int par)
{
    int total = 1;  // the vertex itself
    for (const int child : g[node])
    {
        if (child != par)
        {
            total += dfs(child, node);
        }
    }
    sz[node] = total;
    return total;
}
int currTry;
int calc(int node, int par)
{
int cnt = 0;
int which = -1;
for (const int &u : g[node])
{
if (u == par)
{
continue;
}
cnt += !win[u];
if (!win[u]) which = u;
}
if (cnt >= 2)
{
dp[node] = ((1LL * (sz[node] * n) % MOD) * power(n, 2 * (currTry - 1))) % MOD;
} else if (cnt == 0)
{
dp[node] = sumL;
if (dp[node] < 0) dp[node] += MOD;
for (const int &u : g[node])
{
if (u == par)
{
continue;
}
dp[node] += calc(u, node);
if (dp[node] >= MOD) dp[node] -= MOD;
}
} else
{
dp[node] = ((1LL * ((sz[node] - sz[which]) * n) % MOD) * power(n, 2 * (currTry - 1))) % MOD;
dp[node] += calc(which, node);
if (dp[node] >= MOD) dp[node] -= MOD;
}
assert(dp[node] >= 0 && dp[node] < MOD);
return ((((1LL * (sz[node] * n) % MOD) * power(n, 2 * (currTry - 1))) % MOD) - dp[node] + MOD) % MOD;
}
// Computes and prints the answer for the instance stored in the globals.
//
// Steps:
//   1. calcSubtreeWin + calcUpWin fill cntChildLoses for every vertex, from
//      which win[] and the initial sumL (number of vertices with win == false)
//      are derived.
//   2. For layer = 1 .. d-1 the tree is rooted at every vertex in turn and
//      sumL is replaced by the sum of calc(root, 0) over all roots.
//   3. The printed value is n^(2d) minus calc(1, 0), modulo MOD.
//
// Every layer runs one dfs() and one calc() per root, so the work per layer
// grows with the number of vertices times the tree size.
// Assumes d >= 1 (calc() raises n to the power 2 * (currTry - 1)).
void solve()
{
    calcSubtreeWin(1, 0);
    calcUpWin(1, 0);

    sumL = 0;
    for (int i = 1; i <= n; ++i)
    {
        win[i] = (cntChildLoses[i] > 0);
        sumL += win[i] ? 0 : 1;
    }

    for (int layer = 1; layer < d; ++layer)
    {
        currTry = layer;
        int losingStarts = 0;
        for (int root = 1; root <= n; ++root)
        {
            dfs(root, 0);
            // calc() returns the count for which `root` loses, so the sum over
            // all roots is already the quantity calc() consumes as sumL on the
            // next layer; it must not be subtracted from a total.
            losingStarts += calc(root, 0);
            if (losingStarts >= MOD) losingStarts -= MOD;
        }
        sumL = losingStarts;
    }

    dfs(1, 0);
    currTry = d;
    std::cout << (power(n, 2 * d) - calc(1, 0) + MOD) % MOD << '\n';
}
// Reads n and d from standard input, followed by n - 1 pairs of vertex
// numbers; every pair is stored in both adjacency lists of g.
void input()
{
    std::cin >> n >> d;
    int remainingEdges = n - 1;
    while (remainingEdges-- > 0)
    {
        int a, b;
        std::cin >> a >> b;
        g[a].push_back(b);
        g[b].push_back(a);
    }
}
// Speeds up stream I/O: stops synchronising the C++ streams with C stdio and
// unties std::cin and std::cout from any other stream.
void fastIOI()
{
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);
    std::cout.tie(nullptr);
}
// Entry point: configure the streams, read the instance, then compute and
// print the answer. Falling off the end of main returns 0.
int main()
{
    fastIOI();
    input();
    solve();
}
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Correct |
1 ms |
2652 KB |
Output is correct |
2 |
Execution timed out |
1087 ms |
2652 KB |
Time limit exceeded |
3 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Incorrect |
1 ms |
2652 KB |
Output isn't correct |
2 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Correct |
1 ms |
2652 KB |
Output is correct |
2 |
Incorrect |
1 ms |
2808 KB |
Output isn't correct |
3 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Correct |
1 ms |
2652 KB |
Output is correct |
2 |
Incorrect |
1 ms |
2808 KB |
Output isn't correct |
3 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Correct |
1 ms |
2652 KB |
Output is correct |
2 |
Incorrect |
1 ms |
2808 KB |
Output isn't correct |
3 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Correct |
1 ms |
2652 KB |
Output is correct |
2 |
Incorrect |
1 ms |
2808 KB |
Output isn't correct |
3 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Correct |
1 ms |
2652 KB |
Output is correct |
2 |
Incorrect |
1 ms |
2808 KB |
Output isn't correct |
3 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Correct |
1 ms |
2652 KB |
Output is correct |
2 |
Execution timed out |
1087 ms |
2652 KB |
Time limit exceeded |
3 |
Halted |
0 ms |
0 KB |
- |