#include <bits/stdc++.h>
using namespace std;
using ll = long long;
// Array bounds: nodes are 1-indexed and must satisfy n < maxn; the DP's
// count dimension runs over 0..k, so k < maxk is required.
constexpr int maxn = 100000 + 10, maxk = 110;
// n: number of nodes; k: limit on how many nodes may be picked along a path.
int n, k;
// vet[x]: the integer value read for node x.
int vet[maxn];
// vis[x]: sum of vet over x's neighbours. Kept 64-bit because it adds up to
// n-1 int values, which can overflow a 32-bit int.
ll vis[maxn];
// dp / dp2: best / second-best path gain leaving a node (filled by solve and
// rerooted by rotate); ant: per-node scratch copy used by rotate to restore
// dp[u]; ans: the final answer.
ll dp[maxn][maxk], dp2[maxn][maxk], ant[maxn][maxk], ans;
// Undirected adjacency lists of the tree.
std::vector<int> graph[maxn];
// Bottom-up tree DP over the subtree of u, where f is u's parent (0 = root).
// On return, for every node w in u's subtree and every c in [0, k]:
//   dp[w][c]  = best total gain of a path that starts at w and moves down into
//               w's subtree, picking at most c nodes after w; picking node x
//               entered from y gains vis[x] - vet[y]; the empty path gives 0;
//   dp2[w][c] = second-best of the per-child candidates (0 if none).
// Every child is processed exactly once and the full row 0..k is filled, so
// the work is O(k) per edge. The previous version memoised with
// "dp[u][c] > 0", which re-explored zero-valued states exponentially, and it
// never filled dp[u][c] for small c near the root, which rotate needs.
// Recursion depth equals the height of the tree (up to n).
// Returns dp[u][c].
ll solve(int u, int f = 0, int c = k)
{
    // Reset this node's row so a repeated call cannot corrupt dp2.
    for(int i = 0 ; i <= k ; ++i) dp[u][i] = dp2[u][i] = 0;
    for(int v: graph[u])
    {
        if(v == f) continue;
        solve(v, u);
        for(int i = 1 ; i <= k ; ++i)
        {
            // Either pick v (entered from u) or walk through it unpicked.
            ll val = max(dp[v][i-1] + vis[v] - vet[u], dp[v][i]);
            if(val > dp[u][i])
            {
                dp2[u][i] = dp[u][i];
                dp[u][i] = val;
            }
            else if(val > dp2[u][i]) dp2[u][i] = val;
        }
    }
    return dp[u][c];
}
void rotate(int u, int f = 0)
{
for(int i = 1 ; i <= k ; ++i)
{
ans = max({ans, dp[u][i], dp[u][i-1] + vis[u]});
}
for(int v: graph[u])
{
if(v == f) continue;
for(int i = 1 ; i <= k ; ++i)
{
ant[u][i] = dp[u][i];
ll val = max(dp[v][i], dp[v][i-1] + vis[v] - vet[u]);
if(val == dp[u][i])
{
dp[u][i] = dp2[u][i];
val = max(dp[u][i], dp[u][i-1] + vis[u] - vet[v]);
if(val > dp[v][i]) dp2[v][i] = dp[v][i], dp[v][i] = val;
else if(val > dp2[v][i]) dp2[v][i] = val;
}
}
rotate(v, u);
for(int i = 1 ; i <= k ; ++i) dp[u][i] = ant[u][i];
}
}
int main()
{
cin >> n >> k;
for(int i = 1 ; i <= n ; ++i) cin >> vet[i];
for(int i = 1 ; i < n ; ++i)
{
int u, v;
cin >> u >> v;
graph[u].push_back(v);
graph[v].push_back(u);
}
for(int i = 1 ; i <= n ; ++i)
{
for(int v: graph[i])
{
vis[i] += vet[v];
}
}
solve(1);
rotate(1);
cout << ans << '\n';
}
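/*
 * Judge results:
 * | # | Result    | Time | Memory  | Grader output        |
 * |---|-----------|------|---------|----------------------|
 * | 1 | Correct   | 4 ms | 2688 KB | Output is correct    |
 * | 2 | Incorrect | 4 ms | 2688 KB | Output isn't correct |
 * | 3 | Halted    | 0 ms | 0 KB    | -                    |
 */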
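/*
 * Judge results:
 * | # | Result    | Time | Memory  | Grader output        |
 * |---|-----------|------|---------|----------------------|
 * | 1 | Correct   | 4 ms | 2688 KB | Output is correct    |
 * | 2 | Incorrect | 4 ms | 2688 KB | Output isn't correct |
 * | 3 | Halted    | 0 ms | 0 KB    | -                    |
 */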
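/*
 * Judge results:
 * | # | Result    | Time   | Memory    | Grader output        |
 * |---|-----------|--------|-----------|----------------------|
 * | 1 | Incorrect | 537 ms | 143820 KB | Output isn't correct |
 * | 2 | Halted    | 0 ms   | 0 KB      | -                    |
 */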
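/*
 * Judge results:
 * | # | Result    | Time | Memory  | Grader output        |
 * |---|-----------|------|---------|----------------------|
 * | 1 | Correct   | 4 ms | 2688 KB | Output is correct    |
 * | 2 | Incorrect | 4 ms | 2688 KB | Output isn't correct |
 * | 3 | Halted    | 0 ms | 0 KB    | -                    |
 */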