#include <bits/stdc++.h>
#define mp make_pair
#define X first
#define Y second
using namespace std;
typedef long long ll;
typedef pair <int, int> ii;
const int N = 1e5 + 1;
const int V = 101;
const ll INF = 1e18;
int n, k, a[N];
vector <int> adj[N];
ll s[N], dp0[N][V][2], dp1[N][V][2], res[N];
void readInput(){
scanf("%d %d", &n, &k);
for(int i = 1; i <= n; i++)
scanf("%d", &a[i]);
for(int i = 1; i < n; i++){
int u, v;
scanf("%d %d", &u, &v);
adj[u].push_back(v);
adj[v].push_back(u);
}
}
void maximize(ll &x, const ll &y){
if (x < y)
x = y;
}
void dfs(int u, int p){
for(int v : adj[u])
if (v != p){
s[u] += a[v];
s[v] += a[u];
dfs(v, u);
}
for(int i = 0; i <= k; i++)
for(int mask = 0; mask < 2; mask++)
dp0[u][i][mask] = dp1[u][i][mask] = -INF;
for(int i = 0; i <= k; i++){
dp0[u][i][0] = dp1[u][i][0] = 0;
if (i > 0)
dp0[u][i][1] = dp1[u][i][1] = s[u];
}
for(int v : adj[u])
if (v != p){
for(int i = 0; i <= k; i++){
maximize(res[u], dp0[u][i][0] + dp1[v][k - i][0]);
maximize(res[u], dp0[u][i][1] + dp1[v][k - i][0]);
maximize(res[u], dp0[u][i][0] + dp1[v][k - i][1] - a[u]);
maximize(res[u], dp0[u][i][1] + dp1[v][k - i][1] - a[u]);
maximize(res[u], dp1[u][i][0] + dp0[v][k - i][0]);
maximize(res[u], dp1[u][i][1] + dp0[v][k - i][0] - a[v]);
maximize(res[u], dp1[u][i][0] + dp0[v][k - i][1]);
maximize(res[u], dp1[u][i][1] + dp0[v][k - i][1] - a[v]);
}
for(int i = 0; i <= k; i++){
maximize(dp0[u][i][0], dp0[v][i][0]);
maximize(dp0[u][i][0], dp0[v][i][1]);
maximize(dp1[u][i][0], dp1[v][i][0]);
maximize(dp1[u][i][0], dp1[v][i][1] - a[u]);
if (i > 0){
maximize(dp0[u][i][1], dp0[v][i - 1][0] + s[u] - a[v]);
maximize(dp0[u][i][1], dp0[v][i - 1][1] + s[u] - a[v]);
maximize(dp1[u][i][1], dp1[v][i - 1][0] + s[u]);
maximize(dp1[u][i][1], dp1[v][i - 1][1] + s[u] - a[u]);
}
}
}
maximize(res[u], dp0[u][k][0]);
maximize(res[u], dp0[u][k][1]);
maximize(res[u], dp1[u][k][0]);
maximize(res[u], dp1[u][k][1]);
}
void solve(){
if (k == 0){
printf("0");
return;
}
dfs(1, -1);
printf("%I64d", *max_element(res + 1, res + 1 + n));
}
int main(){
readInput();
solve();
return 0;
}
Compilation message
chase.cpp: In function 'void solve()':
chase.cpp:97:17: warning: format '%d' expects argument of type 'int', but argument 2 has type 'long long int' [-Wformat=]
97 | printf("%I64d", *max_element(res + 1, res + 1 + n));
| ~~~~^ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
| | |
| int long long int
| %I64lld
chase.cpp: In function 'void readInput()':
chase.cpp:22:10: warning: ignoring return value of 'int scanf(const char*, ...)', declared with attribute warn_unused_result [-Wunused-result]
22 | scanf("%d %d", &n, &k);
| ~~~~~^~~~~~~~~~~~~~~~~
chase.cpp:24:14: warning: ignoring return value of 'int scanf(const char*, ...)', declared with attribute warn_unused_result [-Wunused-result]
24 | scanf("%d", &a[i]);
| ~~~~~^~~~~~~~~~~~~
chase.cpp:27:14: warning: ignoring return value of 'int scanf(const char*, ...)', declared with attribute warn_unused_result [-Wunused-result]
27 | scanf("%d %d", &u, &v);
| ~~~~~^~~~~~~~~~~~~~~~~
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Correct |
2 ms |
2796 KB |
Output is correct |
2 |
Correct |
2 ms |
2796 KB |
Output is correct |
3 |
Incorrect |
2 ms |
2668 KB |
Output isn't correct |
4 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Correct |
2 ms |
2796 KB |
Output is correct |
2 |
Correct |
2 ms |
2796 KB |
Output is correct |
3 |
Incorrect |
2 ms |
2668 KB |
Output isn't correct |
4 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Incorrect |
468 ms |
334016 KB |
Output isn't correct |
2 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Correct |
2 ms |
2796 KB |
Output is correct |
2 |
Correct |
2 ms |
2796 KB |
Output is correct |
3 |
Incorrect |
2 ms |
2668 KB |
Output isn't correct |
4 |
Halted |
0 ms |
0 KB |
- |