#include <bits/stdc++.h>
using namespace std;
// Short type aliases used throughout the solution.
using ll = long long;
using vi = vector<int>;
using ii = pair<int, int>;

// for_: int index over [s, e) with e cast to int; for__: long long index over [s, e).
#define for_(i, s, e) for (int i = s; i < (int) e; i++)
#define for__(i, s, e) for (ll i = s; i < e; i++)
// Every `endl` below expands to a bare newline, so output is not flushed per line.
#define endl '\n'

// Capacity limits: vertices are indexed 0..MXN, take budgets 0..MXV.
constexpr int MXN = 100000;
constexpr int MXV = 100;

int n, v;                 // vertex count and take budget, both read in main()
ll dp[MXN + 1][MXV + 1];  // dp[u][c]: filled by init(); best score leaving u with at most c takes
ll nums[MXN + 1];         // nums[u]: value read for vertex u
vi adj[MXN + 1];          // undirected adjacency lists; main() converts 1-based input to 0-based
ll ans = 0;               // best score found so far; only ever raised via max
/**
 * Fills dp[u][0..v] for every vertex u of the tree rooted at p.
 *
 * @param p       root of the traversal.
 * @param parent  neighbour of p to exclude; callers pass parent == p for a root.
 *
 * dp[u][ct] is the best score of a path that starts at u and moves away from the
 * root with at most ct takes. Taking u earns the sum of nums over u's children
 * in this rooting.
 *
 * Iterative on purpose: the recursive version nested one call per tree level,
 * and a path-shaped tree with ~1e5 vertices can exhaust the call stack.
 */
void init(int p, int parent) {
    // (vertex, the neighbour it was reached from) in BFS order:
    // every vertex is listed after the vertex it was reached from.
    vector<pair<int, int>> order(1, make_pair(p, parent));
    for (size_t idx = 0; idx < order.size(); idx++) {
        // Copy the ints first: push_back below may reallocate `order`.
        int u = order[idx].first, from = order[idx].second;
        for (int w : adj[u]) if (w != from) order.push_back(make_pair(w, u));
    }
    // Reverse sweep: each vertex is handled only after all of its children.
    for (size_t idx = order.size(); idx-- > 0; ) {
        int u = order[idx].first, from = order[idx].second;
        for_(ct, 0, v+1) {
            ll s = 0, take = 0, noTake = 0;
            for (int i : adj[u]) if (i != from) {
                s += nums[i];
                if (ct) take = max(take, dp[i][ct-1]);
                noTake = max(noTake, dp[i][ct]);
            }
            // Either skip u and continue into the best child, or take u (earning s).
            dp[u][ct] = noTake;
            if (ct) dp[u][ct] = max(dp[u][ct], s + take);
        }
    }
}
/// Running maximum from the front: element i of the result is max(k[0..i]).
/// `k` is taken by value, so the caller's vector is left untouched.
vector<ll> prefMax(vector<ll> k) {
    partial_sum(k.begin(), k.end(), k.begin(),
                [](ll best, ll cur) { return max(best, cur); });
    return k;
}
/// Running maximum from the back: element i of the result is max(k[i..end)).
/// `k` is taken by value, so the caller's vector is left untouched.
/// The reverse-iterator scan is well-defined for every size, including empty.
/// The old `int i = k.size() - 2` wrapped in unsigned arithmetic for an empty
/// vector, then relied on an implementation-defined narrowing to int.
vector<ll> sufMax(vector<ll> k) {
    partial_sum(k.rbegin(), k.rend(), k.rbegin(),
                [](ll best, ll cur) { return max(best, cur); });
    return k;
}
void reroot(int p, int parent, ll p0, ll p1) {
//cout << "! " << p << " " << parent << " " << p0 << " " << p1 << endl;
int ct = adj[p].size();
vector<ll> take(ct), noTake(ct), takePref, takeSuf, noTakePref, noTakeSuf;
ll s = 0;
for_(i, 0, ct) {
if (adj[p][i] == parent) {
take[i] = p1; noTake[i] = p0;
} else {
take[i] = dp[adj[p][i]][v-1]; noTake[i] = dp[adj[p][i]][v];
}
s += nums[adj[p][i]];
}
takePref = prefMax(take); noTakePref = prefMax(noTake);
takeSuf = sufMax(take); noTakeSuf = sufMax(noTake);
//cout << p << " " << takeSuf[0] + s << " " << noTakeSuf[0] << endl;
ans = max({ans, takeSuf[0] + s, noTakeSuf[0]});
for_(i, 0, ct) {
if (adj[p][i] == parent) continue;
//reroot(adj[p][i], p, max(i > 0 ? noTakePref[i-1] : 0, i < ct-1 ? noTakeSuf[i+1] : 0), max(i > 0 ? takePref[i-1] : 0, i < ct-1 ? takeSuf[i+1] : 0) + s - nums[adj[p][i]]);
}
}
/**
 * Best score over every directed simple path that uses at most k takes.
 *
 * Scoring matches init()/reroot(). Taking vertex x earns the sum of nums over
 * x's neighbours, minus nums of the vertex the path arrived from. Nothing is
 * subtracted at the path's first vertex.
 *
 * The tree is rooted at vertex 0, and every path is split at its highest vertex t:
 *   up[x][j]   best path that starts inside x's subtree and climbs to x, j takes max;
 *   down[x][j] best path that starts at x (entered from x's parent) and descends
 *              into x's subtree, j takes max.
 * At t, the climbing half through one child is joined to the descending half of
 * a different child. Running maxima over the children already seen try every
 * ordered pair in O(k) per child.
 *
 * This gives the same maximum as rooting at every vertex and running
 * init() + reroot(), in O(n * k) time and memory instead of O(n^2 * k) time.
 *
 * Reads the globals n, nums and adj. Assumes n >= 1 and that adj is a tree;
 * vertices unreachable from 0 would simply be skipped.
 */
static ll bestOverAllPaths(int k) {
    const size_t width = (size_t) k + 1;
    vector<ll> up((size_t) n * width), down((size_t) n * width);

    // BFS order from vertex 0: every vertex appears after its parent.
    vector<int> par(n, -1), order;
    order.reserve(n);
    order.push_back(0);
    for (size_t idx = 0; idx < order.size(); idx++) {
        int u = order[idx];
        for (int w : adj[u]) if (w != par[u]) {
            par[w] = u;
            order.push_back(w);
        }
    }

    ll best = 0;
    vector<ll> bestUp(width), bestDown(width), ext(width);
    for (size_t idx = order.size(); idx-- > 0; ) {  // children before parents
        int t = order[idx];
        ll around = 0, below = 0;  // nums summed over all neighbours / children only
        for (int w : adj[t]) {
            around += nums[w];
            if (w != par[t]) below += nums[w];
        }
        // Empty climbing half: the path starts at t, taking t or not.
        bestUp[0] = 0;
        for (int j = 1; j <= k; j++) bestUp[j] = max((ll) 0, around);
        // Empty descending half: the path ends at t.
        fill(bestDown.begin(), bestDown.end(), 0);

        for (int c : adj[t]) {
            if (c == par[t]) continue;
            const ll* cu = &up[(size_t) c * width];
            const ll* cd = &down[(size_t) c * width];
            // Climb out of c's subtree through c into t. Arriving from c,
            // taking t earns around - nums[c].
            ext[0] = cu[0];
            for (int j = 1; j <= k; j++) ext[j] = max(cu[j], cu[j - 1] + around - nums[c]);
            // Pair c with the children seen earlier, in both directions.
            for (int j = 0; j <= k; j++) {
                best = max(best, bestUp[j] + cd[k - j]);
                best = max(best, ext[j] + bestDown[k - j]);
            }
            for (int j = 0; j <= k; j++) {
                bestUp[j] = max(bestUp[j], ext[j]);
                bestDown[j] = max(bestDown[j], cd[j]);
            }
        }
        best = max(best, bestUp[k]);  // covers a leaf t and paths ending at t

        ll* tu = &up[(size_t) t * width];
        ll* td = &down[(size_t) t * width];
        tu[0] = td[0] = 0;
        for (int j = 1; j <= k; j++) {
            tu[j] = bestUp[j];
            // Entered from the parent, taking t earns only its children's nums.
            td[j] = max(bestDown[j], below + bestDown[j - 1]);
        }
    }
    return best;
}

int main() {
#ifdef shiven
    freopen("test.in", "r", stdin);
#endif
    ios_base::sync_with_stdio(false);
    cin.tie(0);
    cin >> n >> v;
    for_(i, 0, n) cin >> nums[i];
    for_(i, 0, n-1) {
        int a, b; cin >> a >> b;
        a -= 1; b -= 1;  // input is 1-based
        adj[a].push_back(b); adj[b].push_back(a);
    }
    if (v <= 0 || n <= 0) {
        cout << 0 << endl;
        return 0;
    }
    // A simple path visits at most n vertices, so a budget above n buys nothing.
    ans = max(ans, bestOverAllPaths(min(v, n)));
    cout << ans << endl;
    return 0;
}
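/*
 * Judge results (one row per test case):
 *
 * | #  | Result (결과) | Time (실행 시간) | Memory (메모리) | Grader output     |
 * |----|---------------|------------------|-----------------|-------------------|
 * | 1  | Correct       | 2 ms             | 2816 KB         | Output is correct |
 * | 2  | Correct       | 2 ms             | 2688 KB         | Output is correct |
 * | 3  | Correct       | 2 ms             | 2688 KB         | Output is correct |
 * | 4  | Correct       | 2 ms             | 2688 KB         | Output is correct |
 * | 5  | Correct       | 2 ms             | 2688 KB         | Output is correct |
 * | 6  | Correct       | 2 ms             | 2688 KB         | Output is correct |
 */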
/*
 * Judge results (one row per test case):
 *
 * | #  | Result (결과) | Time (실행 시간) | Memory (메모리) | Grader output     |
 * |----|---------------|------------------|-----------------|-------------------|
 * | 1  | Correct       | 2 ms             | 2816 KB         | Output is correct |
 * | 2  | Correct       | 2 ms             | 2688 KB         | Output is correct |
 * | 3  | Correct       | 2 ms             | 2688 KB         | Output is correct |
 * | 4  | Correct       | 2 ms             | 2688 KB         | Output is correct |
 * | 5  | Correct       | 2 ms             | 2688 KB         | Output is correct |
 * | 6  | Correct       | 2 ms             | 2688 KB         | Output is correct |
 * | 7  | Correct       | 634 ms           | 3704 KB         | Output is correct |
 * | 8  | Correct       | 54 ms            | 3584 KB         | Output is correct |
 * | 9  | Correct       | 45 ms            | 3704 KB         | Output is correct |
 * | 10 | Correct       | 632 ms           | 3584 KB         | Output is correct |
 * | 11 | Correct       | 216 ms           | 3584 KB         | Output is correct |
 * | 12 | Correct       | 88 ms            | 3576 KB         | Output is correct |
 */
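/*
 * Judge results (one row per test case):
 *
 * | #  | Result (결과)       | Time (실행 시간) | Memory (메모리) | Grader output       |
 * |----|---------------------|------------------|-----------------|---------------------|
 * | 1  | Execution timed out | 4082 ms          | 89592 KB        | Time limit exceeded |
 * | 2  | Halted              | 0 ms             | 0 KB            | -                   |
 */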
/*
 * Judge results (one row per test case):
 *
 * | #  | Result (결과)       | Time (실행 시간) | Memory (메모리) | Grader output       |
 * |----|---------------------|------------------|-----------------|---------------------|
 * | 1  | Correct             | 2 ms             | 2816 KB         | Output is correct   |
 * | 2  | Correct             | 2 ms             | 2688 KB         | Output is correct   |
 * | 3  | Correct             | 2 ms             | 2688 KB         | Output is correct   |
 * | 4  | Correct             | 2 ms             | 2688 KB         | Output is correct   |
 * | 5  | Correct             | 2 ms             | 2688 KB         | Output is correct   |
 * | 6  | Correct             | 2 ms             | 2688 KB         | Output is correct   |
 * | 7  | Correct             | 634 ms           | 3704 KB         | Output is correct   |
 * | 8  | Correct             | 54 ms            | 3584 KB         | Output is correct   |
 * | 9  | Correct             | 45 ms            | 3704 KB         | Output is correct   |
 * | 10 | Correct             | 632 ms           | 3584 KB         | Output is correct   |
 * | 11 | Correct             | 216 ms           | 3584 KB         | Output is correct   |
 * | 12 | Correct             | 88 ms            | 3576 KB         | Output is correct   |
 * | 13 | Execution timed out | 4082 ms          | 89592 KB        | Time limit exceeded |
 * | 14 | Halted              | 0 ms             | 0 KB            | -                   |
 */