#include <bits/stdc++.h>
#pragma GCC optimize("O3", "unroll-loops")
using namespace std;
using ll = long long;
// Capacity of the per-node global arrays. Nodes are numbered from 1, so up to
// N - 1 nodes fit; index 0 is reserved as the "no node" sentinel.
const int N = 100'000 + 1;
// n: number of nodes, breads: crumb budget. Both are read from stdin in main().
int n, breads;
// pigeons[u]: pigeon count read for node u (1-based). pigeons[0] is never
// written and stays 0, which the cost helpers rely on when there is no
// previous node.
int pigeons[N];
// g[u]: adjacency list of node u; main() stores every input edge in both
// directions.
vector<int> g[N];
// namespace sub1 {
// bool check() {
// return (n <= 10);
// }
// int parent[N];
// ll curPigeons[N];
// ll getDiff(int u, ll cur) {
// ll after = 0;
// while (u != 0) {
// after += curPigeons[u];
// u = parent[u];
// }
// return abs(after - cur);
// }
// ll dfs(int u, ll cur, int use) {
// ll res = 0;
// // doesn't drop the breadcrumb
// cur += curPigeons[u];
// res = max(res, getDiff(u, cur));
// for (const auto& v : g[u]) {
// if (v == parent[u]) continue;
// parent[v] = u;
// res = max(res, dfs(v, cur, use));
// }
// // drop the breadcrumb
// if (use < breads) {
// vector<ll> oldPigeons(n + 1, 0);
// oldPigeons[u] = curPigeons[u];
// for (const auto& v : g[u]) {
// oldPigeons[v] = curPigeons[v];
// }
// for (const auto& v : g[u]) {
// curPigeons[u] += curPigeons[v];
// curPigeons[v] = 0;
// }
// res = max(res, getDiff(u, cur));
// for (const auto& v : g[u]) {
// if (v == parent[u]) continue;
// parent[v] = u;
// res = max(res, dfs(v, cur, use + 1));
// }
// curPigeons[u] = oldPigeons[u];
// for (const auto& v : g[u]) {
// curPigeons[v] = oldPigeons[v];
// }
// }
// return res;
// }
// void solve() {
// ll res = 0;
// for (int root = 1; root <= n; ++ root) {
// for (int u = 1; u <= n; ++ u) {
// curPigeons[u] = pigeons[u];
// parent[u] = 0;
// }
// res = max(res, dfs(root, 0, 0));
// }
// cout << res;
// }
// }
// namespace sub2 {
// bool check() {
// return (n <= 1'000);
// }
// ll dp[N][110];
// void dfs(int u, int p) {
// ll bonus = -pigeons[p];
// for (const auto& v : g[u]) {
// bonus += pigeons[v];
// }
// for (int use = breads - 1; use >= 0; -- use) {
// dp[u][use + 1] = max(dp[u][use + 1], dp[u][use] + bonus);
// }
// for (const auto& v : g[u]) {
// if (v == p) continue;
// for (int use = 0; use <= breads; ++ use) {
// dp[v][use] = max(dp[v][use], dp[u][use]);
// }
// dfs(v, u);
// }
// }
// ll calc(int root) {
// for (int u = 0; u <= n + 1; ++ u) {
// for (int use = 0; use <= breads + 1; ++ use) {
// dp[u][use] = 0;
// }
// }
// dfs(root, 0);
// ll res = 0;
// for (int u = 1; u <= n; ++ u) {
// for (int use = 0; use <= breads; ++ use) {
// res = max(res, dp[u][use]);
// }
// }
// return res;
// }
// void solve() {
// ll res = 0;
// for (int root = 1; root <= n; ++ root) {
// res = max(res, calc(root));
// }
// cout << res;
// }
// }
namespace sub4 {
pair<ll, int> dpOut[N][101][2];
pair<ll, int> dpIn[N][101][2];
ll cost[N];
ll res = 0;
ll getCost(int u, int p) {
return cost[u] - pigeons[p];
}
void update(pair<ll, int> cur[], pair<ll, int> val) {
pair<ll, int> old = cur[0];
if (val.first > cur[0].first) {
cur[0] = val;
cur[1] = max(cur[1], old);
}
else {
cur[1] = max(cur[1], val);
}
}
void dfsInOut(int u, int p) {
// out is u can accept len >= 1
for (int use = 0; use <= breads; ++ use) {
dpOut[u][use][0] = (use == 0 ? make_pair(0ll, u) : make_pair(getCost(u, 0), u));
}
// in is u can accept len >= 2
for (int use = 0; use <= breads; ++ use) {
dpIn[u][use][0] = (use == 0 ? make_pair(0ll, -1) : make_pair(getCost(u, p), u));
}
for (const auto& v : g[u]) {
if (v == p) continue;
dfsInOut(v, u);
// case out:
// prefix u = v
for (int use = 0; use <= breads; ++ use) {
pair<ll, int> cur = make_pair(dpOut[v][use][0].first, v);
update(dpOut[u][use], cur);
if (use < breads) {
cur.first += getCost(u, v);
update(dpOut[u][use + 1], cur);
}
}
// case in:
// prefix v = u
for (int use = 0; use <= breads; ++ use) {
pair<ll, int> cur = make_pair(dpIn[v][use][0].first, v);
update(dpIn[u][use], cur);
if (use < breads) {
cur.first += getCost(u, p);
update(dpIn[u][use + 1], cur);
}
}
}
for (const auto& v : g[u]) {
if (v == p) continue;
for (int useOut = 0; useOut <= breads; ++ useOut) {
int useIn = breads - useOut;
for (int typeOut = 0; typeOut < 2; ++ typeOut) {
for (int typeIn = 0; typeIn < 2; ++ typeIn) {
if (dpOut[u][useOut][typeOut].second == v) continue;
res = max(res, dpOut[u][useOut][typeOut].first + dpIn[v][useIn][typeIn].first);
}
}
}
}
for (int useOut = 0; useOut <= breads; ++ useOut) {
res = max(res, dpOut[u][useOut][0].first);
}
}
void solve() {
for (int u = 1; u <= n; ++ u) {
for (const auto& v : g[u]) {
cost[v] += pigeons[u];
}
}
dfsInOut(1, 0);
cout << res;
}
}
int32_t main() {
std::ios_base::sync_with_stdio(false);
std::cin.tie(nullptr); std::cout.tie(nullptr);
cin >> n >> breads;
for (int u = 1; u <= n; ++ u) cin >> pigeons[u];
for (int i = 2; i <= n; ++ i) {
int u, v; cin >> u >> v;
g[u].push_back(v);
g[v].push_back(u);
}
// if (sub1::check()) {
// sub1::solve();
// return 0;
// }
// if (sub2::check()) {
// sub2::solve();
// return 0;
// }
// cout << sub2::calc(1);
sub4::solve();
return (0 ^ 0);
}
/// Code by vuavisao
/*
Judge results (four result tables, in their original order).
Column headers translated from Korean: 결과 = Result, 실행 시간 = Execution time, 메모리 = Memory.

Table 1
| #  | Result | Execution time | Memory  | Grader output     |
|----|--------|----------------|---------|-------------------|
| 1  | Correct | 2 ms | 2652 KB | Output is correct |
| 2  | Correct | 1 ms | 2652 KB | Output is correct |
| 3  | Correct | 1 ms | 2652 KB | Output is correct |
| 4  | Correct | 1 ms | 2652 KB | Output is correct |
| 5  | Correct | 1 ms | 2652 KB | Output is correct |
| 6  | Correct | 1 ms | 2652 KB | Output is correct |

Table 2
| #  | Result | Execution time | Memory  | Grader output     |
|----|--------|----------------|---------|-------------------|
| 1  | Correct | 2 ms | 2652 KB | Output is correct |
| 2  | Correct | 1 ms | 2652 KB | Output is correct |
| 3  | Correct | 1 ms | 2652 KB | Output is correct |
| 4  | Correct | 1 ms | 2652 KB | Output is correct |
| 5  | Correct | 1 ms | 2652 KB | Output is correct |
| 6  | Correct | 1 ms | 2652 KB | Output is correct |
| 7  | Correct | 6 ms | 9052 KB | Output is correct |
| 8  | Correct | 4 ms | 9052 KB | Output is correct |
| 9  | Correct | 4 ms | 9052 KB | Output is correct |
| 10 | Correct | 5 ms | 9052 KB | Output is correct |
| 11 | Correct | 4 ms | 9052 KB | Output is correct |
| 12 | Correct | 4 ms | 9052 KB | Output is correct |

Table 3
| #  | Result | Execution time | Memory    | Grader output                  |
|----|--------|----------------|-----------|--------------------------------|
| 1  | Runtime error | 282 ms | 524288 KB | Execution killed with signal 9 |
| 2  | Halted        | 0 ms   | 0 KB      | -                              |

Table 4
| #  | Result | Execution time | Memory    | Grader output                  |
|----|--------|----------------|-----------|--------------------------------|
| 1  | Correct | 2 ms | 2652 KB | Output is correct |
| 2  | Correct | 1 ms | 2652 KB | Output is correct |
| 3  | Correct | 1 ms | 2652 KB | Output is correct |
| 4  | Correct | 1 ms | 2652 KB | Output is correct |
| 5  | Correct | 1 ms | 2652 KB | Output is correct |
| 6  | Correct | 1 ms | 2652 KB | Output is correct |
| 7  | Correct | 6 ms | 9052 KB | Output is correct |
| 8  | Correct | 4 ms | 9052 KB | Output is correct |
| 9  | Correct | 4 ms | 9052 KB | Output is correct |
| 10 | Correct | 5 ms | 9052 KB | Output is correct |
| 11 | Correct | 4 ms | 9052 KB | Output is correct |
| 12 | Correct | 4 ms | 9052 KB | Output is correct |
| 13 | Runtime error | 282 ms | 524288 KB | Execution killed with signal 9 |
| 14 | Halted        | 0 ms   | 0 KB      | -                              |
*/