#include <bits/stdc++.h>
#pragma GCC optimize("O3", "unroll-loops")
using namespace std;
typedef long long ll;
constexpr int N = 100'000 + 10;  // capacity of the per-vertex arrays (vertices are 1..n)
int n, breads;                   // vertex count and number of breadcrumbs that may be dropped
int pigeons[N];                  // pigeons[u]: initial pigeons at vertex u; index 0 is never read from input and stays 0
std::vector<int> g[N];           // g[u]: neighbours of vertex u
namespace sub1 {
// Brute force for tiny trees: from every start vertex, enumerate every simple walk and
// every choice of at most `breads` drop vertices, simulating the pigeon moves explicitly.
bool check() {
    return (n <= 10);
}
// parent[v]: previous vertex of v on the walk currently being explored (0 for the start).
int parent[N];
// curPigeons[v]: pigeons at v after the drops made so far on the current walk.
ll curPigeons[N];
// Sum of curPigeons over the walk from u back to its start, minus `cur`
// (the total of curPigeons values read when each walk vertex was entered).
ll getDiff(int u, ll cur) {
    ll after = 0;
    while (u != 0) {
        after += curPigeons[u];
        u = parent[u];
    }
    return abs(after - cur);
}
// Explores every way to finish the walk from u.
// cur: total read on entering the previous walk vertices; use: breadcrumbs dropped so far.
ll dfs(int u, ll cur, int use) {
    ll res = 0;
    // Value read on entering u, before any drop at u.
    cur += curPigeons[u];

    // Case 1: no breadcrumb at u; the walk may stop here or continue to a child.
    res = max(res, getDiff(u, cur));
    for (const auto& v : g[u]) {
        if (v == parent[u]) continue;
        parent[v] = u;
        res = max(res, dfs(v, cur, use));
    }

    // Case 2: drop a breadcrumb at u; every neighbour's pigeons move onto u.
    if (use < breads) {
        const ll savedSelf = curPigeons[u];
        vector<ll> savedNeighbours;
        savedNeighbours.reserve(g[u].size());
        for (const auto& v : g[u]) {
            savedNeighbours.push_back(curPigeons[v]);
            curPigeons[u] += curPigeons[v];
            curPigeons[v] = 0;
        }
        res = max(res, getDiff(u, cur));
        for (const auto& v : g[u]) {
            if (v == parent[u]) continue;
            parent[v] = u;
            res = max(res, dfs(v, cur, use + 1));
        }
        // Undo the move so that sibling branches see the previous state.
        curPigeons[u] = savedSelf;
        for (size_t i = 0; i < g[u].size(); ++ i) {
            curPigeons[g[u][i]] = savedNeighbours[i];
        }
    }
    return res;
}
void solve() {
    ll res = 0;
    // Every vertex 1..n may be the start of the walk.
    for (int root = 1; root <= n; ++ root) {
        for (int u = 1; u <= n; ++ u) {
            curPigeons[u] = pigeons[u];
            parent[u] = 0;
        }
        res = max(res, dfs(root, 0, 0));
    }
    cout << res;
}
}
namespace sub2 {
// O(n^2 * breads) solution for n <= 1000: fix the start vertex (root) of the walk, then
// run a DP down the tree rooted there. A drop at u reached from p is credited
// bonus = (sum of pigeons over u's neighbours) - pigeons[p]; pigeons[0] == 0, so the
// start vertex is credited its full neighbour sum.
bool check() {
    return (n <= 1'000);
}
// dp[u][k]: best credited total over the walk root -> ... -> u using at most k drops on
// that walk's vertices. Row 0 is the all-zero state "before the root".
// NOTE(review): assumes breads + 1 < 110 (column count) — confirm against the limits.
ll dp[N][110];
void dfs(int u, int p) {
    ll bonus = -pigeons[p];
    for (const auto& v : g[u]) {
        bonus += pigeons[v];
    }
    // Either skip u (inherit p's values) or spend one more breadcrumb at u.
    dp[u][0] = dp[p][0];
    for (int use = 1; use <= breads; ++ use) {
        dp[u][use] = max(dp[p][use], dp[p][use - 1] + bonus);
    }
    for (const auto& v : g[u]) {
        if (v == p) continue;
        dfs(v, u);
    }
}
// Best value over all walks that start at `root`.
ll calc(int root) {
    for (int u = 0; u <= n + 1; ++ u) {
        for (int use = 0; use <= breads + 1; ++ use) {
            dp[u][use] = 0;
        }
    }
    dfs(root, 0);
    ll res = 0;
    for (int u = 1; u <= n; ++ u) {
        for (int use = 0; use <= breads; ++ use) {
            res = max(res, dp[u][use]);
        }
    }
    return res;
}
void solve() {
    ll res = 0;
    for (int root = 1; root <= n; ++ root) {
        res = max(res, calc(root));
    }
    cout << res;
}
}
int32_t main() {
std::ios_base::sync_with_stdio(false);
std::cin.tie(nullptr); std::cout.tie(nullptr);
cin >> n >> breads;
for (int u = 1; u <= n; ++ u) cin >> pigeons[u];
for (int i = 2; i <= n; ++ i) {
int u, v; cin >> u >> v;
g[u].push_back(v);
g[v].push_back(u);
}
if (sub1::check()) {
sub1::solve();
return 0;
}
if (sub2::check()) {
sub2::solve();
return 0;
}
cout << sub2::calc(1);
return (0 ^ 0);
}
/// Code by vuavisao
/*
 * Judge results of a submission, pasted from the grader page
 * (Korean headers translated: 결과 = Result, 실행 시간 = Time, 메모리 = Memory).
 *
 * Table 1:
 *   # | Result    | Time   | Memory    | Grader output
 *   1 | Incorrect | 1 ms   | 2652 KB   | Output isn't correct
 *   2 | Halted    | 0 ms   | 0 KB      | -
 *
 * Table 2:
 *   # | Result    | Time   | Memory    | Grader output
 *   1 | Incorrect | 1 ms   | 2652 KB   | Output isn't correct
 *   2 | Halted    | 0 ms   | 0 KB      | -
 *
 * Table 3:
 *   # | Result    | Time   | Memory    | Grader output
 *   1 | Incorrect | 114 ms | 96336 KB  | Output isn't correct
 *   2 | Halted    | 0 ms   | 0 KB      | -
 *
 * Table 4:
 *   # | Result    | Time   | Memory    | Grader output
 *   1 | Incorrect | 1 ms   | 2652 KB   | Output isn't correct
 *   2 | Halted    | 0 ms   | 0 KB      | -
 */