#include <bits/stdc++.h>
using namespace std;
#ifndef LOCAL
#define fprintf(...) {}
#endif
using ll = long long;
const int N = 1e5 + 5;
// Raise x to y, but only when x < y; equal or smaller y leaves x untouched.
template<class X, class Y>
inline void maxi(X &x, Y y) {
    if (!(x < y)) return;
    x = y;
}
// adj[v]: undirected adjacency lists as read from input.
// g[v]:   children of v once the tree is rooted (filled by dfs_init).
vector<int> adj[N], g[N];
// n: vertex count; k: cap on the DP's count dimension; a[v]: per-vertex input value.
int n, k, a[N];
// s[v]: sum of a[] over v's neighbours (computed in main).
ll s[N];
// Best value found so far (the program's output). The unused, never-assigned
// `INF` global that used to be declared here was removed.
ll ans = 0;
struct Node {
ll val[101][2];
Node() {
memset(val, -63, sizeof val);
}
void maxi_(const Node &other) {
for (int i = 0; i <= k; i++)
for (int j = 0; j < 2; j++)
maxi(val[i][j], other.val[i][j]);
}
void deal(const Node &other, const int &x, const int &y) { // update from y to x
for (int i = 0; i <= k; i++) {
for (int j = 0; j < 2; j++) {
ll cur = other.val[i][j];
// if (i < k) maxi(val[i + 1][1], cur + s[x] - a[y] + (j ? a[y] : 0));
if (i < k) {
if (j) maxi(val[i + 1][1], cur + s[x] - a[y]);
else maxi(val[i + 1][1], cur + s[x] - a[y]);
maxi(val[i + 1][1], cur + s[x] - a[y] - (j ? a[x] : 0));
}
// maxi(val[i][0], cur + (j ? a[y] : 0));
maxi(val[i][0], cur/* + (j ? a[x] : 0)*/);
}
}
}
Node deal_(const int &x, const int &y) { // update from y to x
Node ans;
for (int i = 0; i <= k; i++) {
for (int j = 0; j < 2; j++) {
ll cur = val[i][j];
// if (i < k) maxi(ans.val[i + 1][1], cur + s[x] - a[y] + (j ? a[y] : 0));
if (i < k) {
if (j) maxi(ans.val[i + 1][1], cur + s[x] - a[y]);
else maxi(ans.val[i + 1][1], cur + s[x] - a[y]);
maxi(ans.val[i + 1][1], cur + s[x] - a[y] - (j ? a[x] : 0));
}
// maxi(ans.val[i][0], cur + (j ? a[y] : 0));
maxi(ans.val[i][0], cur/* + (j ? a[x] : 0)*/);
}
}
return ans;
}
ll res() const {
ll ans = 0;
for (int i = 0; i <= k; i++)
for (int j = 0; j < 2; j++)
maxi(ans, val[i][j]);
return ans;
}
void debug(int x) {
#ifndef LOCAL
return;
#endif
cerr << string(15, '=') << '\n';
for (int i = 0; i <= k; i++)
for (int j = 0; j < 2; j++)
fprintf(stderr, "dp[%d][%d][%d] = %lld\n", x, i, j, val[i][j]);
cerr << string(15, '=') << '\n';
}
};
vector<Node> pref[N], suf[N];
Node dp[N];
void dfs_init(int x, int p) {
for (int y: adj[x]) {
if (y == p) continue;
g[x].emplace_back(y);
dfs_init(y, x);
}
}
void dfs(int x) {
pref[x].resize(g[x].size());
suf[x].resize(g[x].size());
for (int y: g[x]) dfs(y);
for (int i = 0, o = g[x].size(); i < o; i++) {
int y = g[x][i];
auto &p = pref[x][i];
if (i) p.maxi_(pref[x][i - 1]);
p.deal(dp[y], x, y);
}
for (int o = g[x].size(), i = o - 1; i >= 0; i--) {
int y = g[x][i];
auto &p = suf[x][i];
if (i < o - 1) p.maxi_(suf[x][i + 1]);
p.deal(dp[y], x, y);
}
if (!g[x].empty()) dp[x] = pref[x].back();
// else {
maxi(dp[x].val[0][0], 0);
if (k > 0) maxi(dp[x].val[1][1], s[x]);
// }
maxi(ans, dp[x].res());
dp[x].debug(x);
}
void dfs_reroot(int x, Node down) {
maxi(down.val[0][0], 0);
if (k > 0) maxi(down.val[1][1], s[x]);
maxi(ans, down.res());
for (int i = 0, o = g[x].size(); i < o; i++) {
int y = g[x][i];
Node cur = down;
if (i > 0) cur.maxi_(pref[x][i - 1]);
if (i + 1 < o) cur.maxi_(suf[x][i + 1]);
dfs_reroot(y, cur.deal_(y, x));
}
}
int32_t main() {
#ifndef LOCAL
cin.tie(0)->sync_with_stdio(0); // ====================================================
#endif
cin >> n >> k;
for (int i = 1; i <= n; i++)
cin >> a[i];
for (int i = 1; i < n; i++) {
int x, y;
cin >> x >> y;
adj[x].emplace_back(y);
adj[y].emplace_back(x);
}
for (int x = 1; x <= n; x++) {
for (int y: adj[x])
s[x] += a[y];
fprintf(stderr, "s[%d] = %lld\n", x, s[x]);
}
dfs_init(1, 0);
dfs(1);
dfs_reroot(1, Node()); // ============================================================
cout << ans;
return 0;
}
// Judge results, table 1. Columns: # | 결과 (result) | 실행 시간 (run time) | 메모리 (memory) | Grader output
// 1 | Correct | 21 ms | 169048 KB | Output is correct
// 2 | Correct | 22 ms | 169104 KB | Output is correct
// 3 | Correct | 22 ms | 169052 KB | Output is correct
// 4 | Correct | 22 ms | 169036 KB | Output is correct
// 5 | Correct | 22 ms | 169048 KB | Output is correct
// 6 | Correct | 21 ms | 169040 KB | Output is correct
// Judge results, table 2. Columns: # | 결과 (result) | 실행 시간 (run time) | 메모리 (memory) | Grader output
// 1 | Correct | 21 ms | 169048 KB | Output is correct
// 2 | Correct | 22 ms | 169104 KB | Output is correct
// 3 | Correct | 22 ms | 169052 KB | Output is correct
// 4 | Correct | 22 ms | 169036 KB | Output is correct
// 5 | Correct | 22 ms | 169048 KB | Output is correct
// 6 | Correct | 21 ms | 169040 KB | Output is correct
// 7 | Correct | 27 ms | 174672 KB | Output is correct
// 8 | Correct | 26 ms | 174808 KB | Output is correct
// 9 | Correct | 24 ms | 172368 KB | Output is correct
// 10 | Correct | 25 ms | 172376 KB | Output is correct
// 11 | Correct | 24 ms | 172372 KB | Output is correct
// 12 | Correct | 23 ms | 172376 KB | Output is correct
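// Judge results, table 3. Columns: # | 결과 (result) | 실행 시간 (run time) | 메모리 (memory) | Grader output
// 1 | Runtime error | 373 ms | 524288 KB | Execution killed with signal 9
// 2 | Halted | 0 ms | 0 KB | -
// (524288 KB is exactly 512 MiB, presumably the memory limit: the process was killed with SIGKILL.)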
// Judge results, table 4. Columns: # | 결과 (result) | 실행 시간 (run time) | 메모리 (memory) | Grader output
// 1 | Correct | 21 ms | 169048 KB | Output is correct
// 2 | Correct | 22 ms | 169104 KB | Output is correct
// 3 | Correct | 22 ms | 169052 KB | Output is correct
// 4 | Correct | 22 ms | 169036 KB | Output is correct
// 5 | Correct | 22 ms | 169048 KB | Output is correct
// 6 | Correct | 21 ms | 169040 KB | Output is correct
// 7 | Correct | 27 ms | 174672 KB | Output is correct
// 8 | Correct | 26 ms | 174808 KB | Output is correct
// 9 | Correct | 24 ms | 172368 KB | Output is correct
// 10 | Correct | 25 ms | 172376 KB | Output is correct
// 11 | Correct | 24 ms | 172372 KB | Output is correct
// 12 | Correct | 23 ms | 172376 KB | Output is correct
// 13 | Runtime error | 373 ms | 524288 KB | Execution killed with signal 9
// 14 | Halted | 0 ms | 0 KB | -
// (524288 KB is exactly 512 MiB, presumably the memory limit: test 13 was killed with SIGKILL.)