#include <bits/stdc++.h>
#define mp make_pair
#define X first
#define Y second
using namespace std;
typedef long long ll;
typedef pair <int, int> ii;
const int N = 1e5 + 1;
const int V = 101;
const ll INF = 1e18;
int n, k, a[N];
vector <int> adj[N];
ll s[N], dp0[N][V][2], dp1[N][V][2], res[N];
void readInput(){
cin >> n >> k;
for(int i = 1; i <= n; i++)
cin >> a[i];
for(int i = 1; i < n; i++){
int u, v;
cin >> u >> v;
adj[u].push_back(v);
adj[v].push_back(u);
}
}
void maximize(ll &x, const ll &y){
if (x < y)
x = y;
}
void dfs(int u, int p){
for(int v : adj[u])
if (v != p){
s[u] += a[v];
s[v] += a[u];
dfs(v, u);
}
for(int i = 0; i <= k; i++)
for(int mask = 0; mask < 2; mask++)
dp0[u][i][mask] = dp1[u][i][mask] = -INF;
for(int i = 0; i <= k; i++){
dp0[u][i][0] = dp1[u][i][0] = 0;
if (i > 0)
dp0[u][i][1] = dp1[u][i][1] = s[u];
}
for(int v : adj[u])
if (v != p){
for(int i = 0; i <= k; i++){
maximize(res[u], dp0[u][i][0] + dp1[v][k - i][0]);
maximize(res[u], dp0[u][i][1] + dp1[v][k - i][0]);
maximize(res[u], dp0[u][i][0] + dp1[v][k - i][1] - a[u]);
maximize(res[u], dp0[u][i][1] + dp1[v][k - i][1] - a[u]);
maximize(res[u], dp1[u][i][0] + dp0[v][k - i][0]);
maximize(res[u], dp1[u][i][1] + dp0[v][k - i][0] - a[v]);
maximize(res[u], dp1[u][i][0] + dp0[v][k - i][1]);
maximize(res[u], dp1[u][i][1] + dp0[v][k - i][1] - a[v]);
}
for(int i = 0; i <= k; i++){
maximize(dp0[u][i][0], dp0[v][i][0]);
maximize(dp0[u][i][0], dp0[v][i][1]);
maximize(dp1[u][i][0], dp1[v][i][0]);
maximize(dp1[u][i][0], dp1[v][i][1] - a[u]);
if (i > 0){
maximize(dp0[u][i][1], dp0[v][i - 1][0] + s[u] - a[v]);
maximize(dp0[u][i][1], dp0[v][i - 1][1] + s[u] - a[v]);
maximize(dp1[u][i][1], dp1[v][i - 1][0] + s[u]);
maximize(dp1[u][i][1], dp1[v][i - 1][1] + s[u] - a[u]);
}
}
}
maximize(res[u], dp0[u][k][0]);
maximize(res[u], dp0[u][k][1]);
maximize(res[u], dp1[u][k][0]);
maximize(res[u], dp1[u][k][1]);
}
void solve(){
if (k == 0){
printf("0");
return;
}
dfs(1, -1);
cout << *max_element(res + 1, res + 1 + n);
}
int main(){
readInput();
solve();
return 0;
}
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Incorrect |
2 ms |
2796 KB |
Output isn't correct |
2 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Runtime error |
140 ms |
131072 KB |
Execution killed with signal 9 (could be triggered by violating memory limits) |
2 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Incorrect |
2 ms |
2796 KB |
Output isn't correct |
2 |
Halted |
0 ms |
0 KB |
- |