#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int MAXN = 1e5;  // maximum number of nodes (nodes are 0-indexed)
const int MAXV = 101;  // maximum breadcrumb budget v, plus one slot for j = 0

// Gain model used throughout: placing a breadcrumb on node y that was entered from node w
// gains adj_sum[y] - val[w]; placing one on the first node of a path gains adj_sum[y].
//
// dp[d][x][j], for a node x of the component currently being decomposed, rooted at its centroid:
//   d = 0 ("away"):    best gain of a path that enters x from its parent and continues downwards
//                      into x's subtree, using at most j breadcrumbs;
//   d = 1 ("towards"): best gain of a path that starts inside x's subtree and climbs up to end
//                      at x, using at most j breadcrumbs.
// For the centroid itself, dp[1][c] is reused as an accumulator by merge_through_centroid().
ll dp[2][MAXN][MAXV];
int val[MAXN];            // value of each node, as read from input
ll adj_sum[MAXN];         // sum of val[] over all neighbours of each node
vector<int> edges[MAXN];  // adjacency lists
bool vis[MAXN];           // true once a node has been used as a centroid (removed)
int sz[MAXN];             // subtree sizes inside the current component
int n, v;                 // number of nodes, breadcrumb budget
ll ans = 0;               // best total gain found so far

// Computes sz[] for the component containing cur, rooted at cur, ignoring removed nodes.
static void calc_sizes(int cur, int p = -1) {
    sz[cur] = 1;
    for (int nxt : edges[cur]) {
        if (nxt == p || vis[nxt]) continue;
        calc_sizes(nxt, cur);
        sz[cur] += sz[nxt];
    }
}

// Roots the current component at cur (entered from p, or p == -1 for the root) and computes
// sz[] plus dp[0][x][0..v] and dp[1][x][0..v] for every non-removed node x below the root.
// NOTE: recursion depth can reach the size of the component.
void get_sz(int cur, int p = -1) {
    sz[cur] = 1;
    // Reset first: cur may already have been visited at an earlier decomposition level with a
    // different parent, and those values describe paths that are no longer valid here.
    fill(dp[0][cur], dp[0][cur] + v + 1, 0LL);
    fill(dp[1][cur], dp[1][cur] + v + 1, 0LL);
    if (p != -1) {
        for (int j = 1; j <= v; j++) {
            dp[0][cur][j] = adj_sum[cur] - val[p];  // path ends at cur, breadcrumb on cur
            dp[1][cur][j] = adj_sum[cur];           // path starts at cur, breadcrumb on cur
        }
    }
    for (int nxt : edges[cur]) {
        if (nxt == p || vis[nxt]) continue;
        get_sz(nxt, cur);
        sz[cur] += sz[nxt];
        if (p == -1) continue;  // the root's dp is handled by merge_through_centroid()
        for (int j = 1; j <= v; j++) {
            // Continue p -> cur -> nxt -> ..., optionally dropping on cur.
            dp[0][cur][j] = max(dp[0][cur][j],
                                max(dp[0][nxt][j], dp[0][nxt][j - 1] + adj_sum[cur] - val[p]));
            // Climb ... -> nxt -> cur, optionally dropping on cur.
            dp[1][cur][j] = max(dp[1][cur][j],
                                max(dp[1][nxt][j], dp[1][nxt][j - 1] + adj_sum[cur] - val[nxt]));
        }
    }
}

// Returns the centroid of the component rooted at root. Requires sz[] from calc_sizes(root).
static int find_centroid(int root) {
    const int total = sz[root];
    int cur = root;
    int p = -1;
    bool moved = true;
    while (moved) {
        moved = false;
        for (int nxt : edges[cur]) {
            if (nxt == p || vis[nxt]) continue;
            if (sz[nxt] > total / 2) {
                p = cur;
                cur = nxt;
                moved = true;
                break;
            }
        }
    }
    return cur;
}

// Folds every path that passes through centroid c into ans. Requires get_sz(c) to have filled
// dp for c's component. dp[1][c] accumulates "best path ending at c" over the children that
// have been processed so far in the current pass.
static void merge_through_centroid(int c) {
    const int deg = static_cast<int>(edges[c].size());
    for (int pass = 0; pass < 2; pass++) {
        // A path may start at c: no breadcrumb gives 0, a breadcrumb on c gives adj_sum[c].
        dp[1][c][0] = 0;
        fill(dp[1][c] + 1, dp[1][c] + v + 1, adj_sum[c]);
        for (int k = 0; k < deg; k++) {
            // The second pass walks the children in reverse, so each ordered pair of distinct
            // child subtrees (source, destination) is joined with the source seen first.
            const int nxt = edges[c][pass == 0 ? k : deg - 1 - k];
            if (vis[nxt]) continue;
            // Join a path ending at c (from earlier children) with a path leaving c into nxt.
            for (int i = 0; i <= v; i++)
                ans = max(ans, dp[1][c][i] + dp[0][nxt][v - i]);
            // Extend paths climbing out of nxt's subtree by the step nxt -> c.
            for (int i = 1; i <= v; i++)
                dp[1][c][i] = max(dp[1][c][i],
                                  max(dp[1][nxt][i], dp[1][nxt][i - 1] + adj_sum[c] - val[nxt]));
        }
    }
    // Paths that end at c.
    ans = max(ans, dp[1][c][v]);
}

// Centroid decomposition of the component containing cur. Every path is counted at the level
// where it first contains the chosen centroid. Total work is O(n log n * v).
void decomp(int cur) {
    calc_sizes(cur);
    const int c = find_centroid(cur);
    get_sz(c);  // the DP must be rooted at the centroid that is removed below
    merge_through_centroid(c);
    vis[c] = true;
    for (int nxt : edges[c])
        if (!vis[nxt]) decomp(nxt);
}
// Reads n, v, the n node values and the n - 1 (1-indexed) edges, builds the 0-indexed adjacency
// lists and per-node neighbour-value sums, runs the decomposition and prints the best gain.
int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);

    cin >> n >> v;
    for (int idx = 0; idx < n; ++idx)
        cin >> val[idx];

    for (int e = 1; e < n; ++e) {
        int u, w;
        cin >> u >> w;
        --u;
        --w;
        edges[u].push_back(w);
        edges[w].push_back(u);
        // Each endpoint's neighbour sum includes the other endpoint's value.
        adj_sum[u] += val[w];
        adj_sum[w] += val[u];
    }

    decomp(0);
    cout << ans << "\n";
    return 0;
}
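/*
 * Judge results (pasted submission output):
 * | # | Verdict   | Execution time | Memory  | Grader output        |
 * |---|-----------|----------------|---------|----------------------|
 * | 1 | Correct   | 2 ms           | 2684 KB | Output is correct    |
 * | 2 | Incorrect | 2 ms           | 2644 KB | Output isn't correct |
 * | 3 | Halted    | 0 ms           | 0 KB    | -                    |
 */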
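/*
 * Judge results (pasted submission output):
 * | # | Verdict   | Execution time | Memory  | Grader output        |
 * |---|-----------|----------------|---------|----------------------|
 * | 1 | Correct   | 2 ms           | 2684 KB | Output is correct    |
 * | 2 | Incorrect | 2 ms           | 2644 KB | Output isn't correct |
 * | 3 | Halted    | 0 ms           | 0 KB    | -                    |
 */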
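/*
 * Judge results (pasted submission output):
 * | # | Verdict   | Execution time | Memory    | Grader output        |
 * |---|-----------|----------------|-----------|----------------------|
 * | 1 | Incorrect | 983 ms         | 172412 KB | Output isn't correct |
 * | 2 | Halted    | 0 ms           | 0 KB      | -                    |
 */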
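/*
 * Judge results (pasted submission output):
 * | # | Verdict   | Execution time | Memory  | Grader output        |
 * |---|-----------|----------------|---------|----------------------|
 * | 1 | Correct   | 2 ms           | 2684 KB | Output is correct    |
 * | 2 | Incorrect | 2 ms           | 2644 KB | Output isn't correct |
 * | 3 | Halted    | 0 ms           | 0 KB    | -                    |
 */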