#include <bits/stdc++.h>
using namespace std;
using ll = long long;
// N: array size for vertices 1..N-1 (index 0 is the sentinel "parent" of the root).
// V: capacity of dp's breadcrumb-count dimension; counts 0..lim are indexed, so lim must be < V.
const int N = 1e5+1, V = 102;
// INF: sentinel for unreachable dp states (assumes every real path value stays far below 1e17).
const ll INF = 1e17;
// n: number of vertices; lim: maximum number of breadcrumbs; a[i]: weight of vertex i
// (1-indexed, read from input; a[0] stays 0 and serves as the weight of the root's sentinel parent).
int n, lim, a[N];
// ans: best path value found so far (zero-initialised, so the empty choice scores 0).
// s[u]: built in main as the sum of a[] over u's neighbours; dfs1 then subtracts a[parent(u)].
// dp[u][x][dir][last]: best value of a path inside u's subtree that has u as an endpoint and
//   uses exactly x breadcrumbs (-INF when no such path exists).
//   dir = 0 (down): the path runs from u down into the subtree; u's own gain is computed as if
//                   the path reached u from parent(u).
//   dir = 1 (up):   the path climbs out of the subtree into u and then continues to parent(u).
//   last = 1 iff a breadcrumb is dropped at u. For dir = 1, last = 1 the stored value omits the
//   +a[parent(u)] term; whoever reads the entry from the parent side adds it.
ll ans, s[N], dp[N][V][2][2]; // vertex, cnt, dir = down, up, last
// g[u]: adjacency list of vertex u (undirected tree).
vector<int> g[N];
void dfs1(int u, int p) {
s[u] -= a[p];
dp[u][0][0][0] = 0;
dp[u][0][1][0] = 0;
dp[u][1][0][1] = s[u];
dp[u][1][1][1] = s[u];
for(int v: g[u]) if(v != p) {
dfs1(v, u);
for(int x = 0; x <= lim; x++) {
ll res[2][2] = {
{
max(dp[v][x][0][0], dp[v][x][0][1]),
(x > 0 ? max(dp[v][x-1][0][0], dp[v][x-1][0][1]) + s[u] : -INF) // a[v] - a[v] cancel
},
{
max(dp[v][x][1][0], dp[v][x][1][1] + a[u]),
(x > 0 ? max(dp[v][x-1][1][0], dp[v][x-1][1][1] + a[u]) + s[u] - a[v] : -INF)
}
};
for(int i = 0; i < 2; i++) for(int j = 0; j < 2; j++) {
dp[u][x][i][j] = max(res[i][j], dp[u][x][i][j]);
if(dp[u][x][i][j] < 0) dp[u][x][i][j] = -INF;
}
}
}
}
void dfs2(int u, int p) {
for(int v: g[u]) if(v != p) dfs2(v, u);
vector<ll> mx(lim+1);
ans = max(s[u] + a[p], ans);
for(int v: g[u]) if(v != p) {
for(int x = 0; x <= lim; x++) {
ll res = max(dp[v][x][1][0], dp[v][x][1][1] + a[u]);
res += max((x < lim ? mx[lim-x-1] + s[u] - a[v] + a[p] : 0), mx[lim-x]);
ans = max(res, ans);
}
for(int x = 0; x <= lim; x++) {
mx[x] = max(max(dp[v][x][0][0], dp[v][x][0][1]), mx[x]);
mx[x] = max((x > 0 ? mx[x-1] : 0), mx[x]);
}
}
reverse(g[u].begin(), g[u].end());
mx.clear();
mx.resize(lim+1);
for(int v: g[u]) if(v != p) {
for(int x = 0; x <= lim; x++) {
ll res = max(dp[v][x][1][0], dp[v][x][1][1] + a[u]);
res += max((x < lim ? mx[lim-x-1] + s[u] - a[v] + a[p] : 0), mx[lim-x]);
ans = max(res, ans);
}
for(int x = 0; x <= lim; x++) {
mx[x] = max(max(dp[v][x][0][0], dp[v][x][0][1]), mx[x]);
mx[x] = max((x > 0 ? mx[x-1] : 0), mx[x]);
}
}
}
int main() {
scanf("%d %d", &n, &lim);
for(int i = 1; i <= n; i++) scanf("%d", a+i);
for(int i = 1, u, v; i < n; i++) {
scanf("%d %d", &u, &v);
g[u].push_back(v);
g[v].push_back(u);
s[u] += a[v];
s[v] += a[u];
}
for(int v = 0; v <= n; v++)
for(int x = 0; x <= lim; x++)
for(int i = 0; i < 2; i++) for(int j = 0; j < 2; j++) dp[v][x][i][j] = -INF;
dfs1(1, 0);
dfs2(1, 0);
printf("%lld", ans);
}
/*
Compilation message (line numbers refer to the submitted chase.cpp, not to this listing):
chase.cpp: In function 'int main()':
chase.cpp:73:26: warning: ignoring return value of 'int scanf(const char*, ...)', declared with attribute warn_unused_result [-Wunused-result]
 scanf("%d %d", &n, &lim);
                          ^
chase.cpp:74:46: warning: ignoring return value of 'int scanf(const char*, ...)', declared with attribute warn_unused_result [-Wunused-result]
 for(int i = 1; i <= n; i++) scanf("%d", a+i);
                                              ^
chase.cpp:76:25: warning: ignoring return value of 'int scanf(const char*, ...)', declared with attribute warn_unused_result [-Wunused-result]
 scanf("%d %d", &u, &v);
                         ^
*/
/*
Grader results for one test group:
# | Result    | Execution time | Memory    | Grader output
1 | Correct   | 0 ms           | 324288 KB | Output is correct
2 | Correct   | 0 ms           | 324288 KB | Output is correct
3 | Correct   | 0 ms           | 324288 KB | Output is correct
4 | Correct   | 0 ms           | 324288 KB | Output is correct
5 | Incorrect | 0 ms           | 324288 KB | Output isn't correct
6 | Halted    | 0 ms           | 0 KB      | -
*/
/*
Grader results for one test group:
# | Result    | Execution time | Memory    | Grader output
1 | Correct   | 0 ms           | 324288 KB | Output is correct
2 | Correct   | 0 ms           | 324288 KB | Output is correct
3 | Correct   | 0 ms           | 324288 KB | Output is correct
4 | Correct   | 0 ms           | 324288 KB | Output is correct
5 | Incorrect | 0 ms           | 324288 KB | Output isn't correct
6 | Halted    | 0 ms           | 0 KB      | -
*/
/*
Grader results for one test group:
# | Result    | Execution time | Memory    | Grader output
1 | Incorrect | 569 ms         | 335076 KB | Output isn't correct
2 | Halted    | 0 ms           | 0 KB      | -
*/
/*
Grader results for one test group:
# | Result    | Execution time | Memory    | Grader output
1 | Correct   | 0 ms           | 324288 KB | Output is correct
2 | Correct   | 0 ms           | 324288 KB | Output is correct
3 | Correct   | 0 ms           | 324288 KB | Output is correct
4 | Correct   | 0 ms           | 324288 KB | Output is correct
5 | Incorrect | 0 ms           | 324288 KB | Output isn't correct
6 | Halted    | 0 ms           | 0 KB      | -
*/