#include <bits/stdc++.h>
using namespace std;
// 64-bit type for pigeon counts and every sum built from them.
using ll = long long;

namespace {

/// Maximum total gain over all directed simple paths of a tree when at most
/// `k` vertices on the path are marked (receive a breadcrumb).
///
/// Gain model: walking x1 -> x2 -> ... -> xm and marking xi adds
/// nb[xi] - p[x(i-1)], where nb[v] is the sum of p over v's neighbours and the
/// predecessor term is 0 for the first vertex x1.
///
/// The tree is rooted at vertex 0 and processed bottom-up:
///   up[v][j]   best path that starts somewhere in subtree(v) and walks upward
///              to end at v, using at most j marks (v's own mark included);
///   down[v][j] best path that starts at v and walks downward into subtree(v),
///              using at most j marks on vertices strictly below v (v's own
///              gain depends on its predecessor, which lies outside the subtree).
/// Every path is counted exactly at its topmost vertex u, by joining an "up"
/// part arriving from one child with a "down" part leaving through a different
/// child (or starting / ending at u itself).
///
/// @param n      number of vertices, labelled 0 .. n-1
/// @param k      maximum number of marks (values above n are clamped to n)
/// @param p      p[v] for every vertex; must have at least n entries
/// @param edges  the n-1 tree edges, 0-based endpoints
/// @return the best achievable gain; 0 when n <= 0 or k <= 0
/// Runs in O(n * min(k, n)) time and memory, without recursion.
ll chase(int n, ll k, const std::vector<ll>& p,
         const std::vector<std::pair<int, int>>& edges) {
    if (n <= 0 || k <= 0) {
        return 0;
    }
    // A simple path has at most n vertices, so more than n marks never help.
    const std::size_t K = static_cast<std::size_t>(std::min<ll>(k, n));
    const std::size_t width = K + 1;

    std::vector<std::vector<int>> g(n);
    for (const auto& [a, b] : edges) {
        g[a].push_back(b);
        g[b].push_back(a);
    }

    // nb[v] = sum of p over the neighbours of v.
    std::vector<ll> nb(n, 0);
    for (int v = 0; v < n; ++v) {
        for (const int w : g[v]) {
            nb[v] += p[w];
        }
    }

    // Iterative pre-order from vertex 0: each parent precedes its children, so
    // walking `order` backwards visits children first. An explicit stack avoids
    // recursion depth n on path-shaped trees.
    std::vector<int> order;
    order.reserve(n);
    std::vector<int> parent(n, -1);
    std::vector<int> todo{0};
    while (!todo.empty()) {
        const int v = todo.back();
        todo.pop_back();
        order.push_back(v);
        for (const int w : g[v]) {
            if (w != parent[v]) {
                parent[w] = v;
                todo.push_back(w);
            }
        }
    }

    // Flat n x width tables: row v holds up[v][0..K] / down[v][0..K].
    const std::size_t cells = static_cast<std::size_t>(n) * width;
    std::vector<ll> up(cells, 0);
    std::vector<ll> down(cells, 0);
    std::vector<ll> viaUp(width);    // from a child's subtree up to u, ending at u
    std::vector<ll> viaDown(width);  // from u down into a child's subtree
    ll best = 0;

    for (auto it = order.rbegin(); it != order.rend(); ++it) {
        const int u = *it;
        ll* const upU = &up[static_cast<std::size_t>(u) * width];
        ll* const downU = &down[static_cast<std::size_t>(u) * width];

        // The path may start at u; marking u then gains nb[u].
        for (std::size_t j = 1; j < width; ++j) {
            upU[j] = nb[u];
        }
        best = std::max(best, upU[K]);

        for (const int c : g[u]) {
            if (c == parent[u]) {
                continue;
            }
            const ll* const upC = &up[static_cast<std::size_t>(c) * width];
            const ll* const downC = &down[static_cast<std::size_t>(c) * width];

            viaUp[0] = 0;
            viaDown[0] = 0;
            for (std::size_t j = 1; j < width; ++j) {
                // Marking u after arriving from c gains nb[u] - p[c].
                viaUp[j] = std::max(upC[j], upC[j - 1] + nb[u] - p[c]);
                // Marking c after arriving from u gains nb[c] - p[u].
                viaDown[j] = std::max(downC[j], downC[j - 1] + nb[c] - p[u]);
            }

            // Join child c with the children merged so far (plus "start at u" /
            // "end at u"). The tables are non-decreasing in j, so trying every
            // split of exactly K marks covers every split of at most K.
            for (std::size_t j = 0; j < width; ++j) {
                best = std::max({best, upU[j] + viaDown[K - j], viaUp[j] + downU[K - j]});
            }

            for (std::size_t j = 0; j < width; ++j) {
                upU[j] = std::max(upU[j], viaUp[j]);
                downU[j] = std::max(downU[j], viaDown[j]);
            }
        }
    }
    return best;
}

}  // namespace

/// Reads one instance from standard input and prints the answer.
/// Input layout (as parsed below):
///   n k
///   p_1 ... p_n
///   n-1 lines "u v" describing the tree edges with 1-based labels
void solve() {
    int n = 0;
    ll k = 0;
    std::cin >> n >> k;
    std::vector<ll> p(static_cast<std::size_t>(std::max(n, 0)));
    for (ll& x : p) {
        std::cin >> x;
    }
    std::vector<std::pair<int, int>> edges;
    if (n > 1) {
        edges.reserve(static_cast<std::size_t>(n - 1));
    }
    for (int i = 1; i < n; ++i) {
        int a = 0;
        int b = 0;
        std::cin >> a >> b;
        edges.emplace_back(a - 1, b - 1);  // convert to 0-based labels
    }
    std::cout << chase(n, k, p, edges) << '\n';
}
signed main() {
    // Decouple C++ streams from C stdio and untie them for faster bulk I/O.
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    std::cout.tie(nullptr);
    // Exactly one test case; read a case count here if the input ever has several.
    for (int remaining = 1; remaining > 0; --remaining) {
        solve();
    }
}
| # | 결과 (Result) | 실행 시간 (Execution time) | 메모리 (Memory) | Grader output |
|---|---|---|---|---|
| 1 | Incorrect | 1 ms | 2684 KB | Output isn't correct |
| 2 | Halted | 0 ms | 0 KB | - |
| # | 결과 (Result) | 실행 시간 (Execution time) | 메모리 (Memory) | Grader output |
|---|---|---|---|---|
| 1 | Incorrect | 1 ms | 2684 KB | Output isn't correct |
| 2 | Halted | 0 ms | 0 KB | - |
| # | 결과 (Result) | 실행 시간 (Execution time) | 메모리 (Memory) | Grader output |
|---|---|---|---|---|
| 1 | Incorrect | 368 ms | 349652 KB | Output isn't correct |
| 2 | Halted | 0 ms | 0 KB | - |
| # | 결과 (Result) | 실행 시간 (Execution time) | 메모리 (Memory) | Grader output |
|---|---|---|---|---|
| 1 | Incorrect | 1 ms | 2684 KB | Output isn't correct |
| 2 | Halted | 0 ms | 0 KB | - |