#include <bits/stdc++.h>
using namespace std;
using ll = long long;
namespace {

// Breadth-first search over `adj` starting at `source`.
// Returns a vertex whose edge distance from `source` is maximal among the
// vertices reached. On ties, the earliest such vertex in BFS dequeue order wins,
// because the candidate is replaced only on a strictly larger distance.
// A plain vector with a moving head index serves as the FIFO queue.
int farthest_vertex(const std::vector<std::vector<int>>& adj, int source) {
    std::vector<int> depth(adj.size(), -1);  // -1 marks "not yet reached"
    std::vector<int> frontier{source};
    depth[source] = 0;
    int far = source;
    for (std::size_t head = 0; head < frontier.size(); ++head) {
        const int cur = frontier[head];  // copy: push_back below may reallocate
        if (depth[cur] > depth[far]) far = cur;
        for (const int nxt : adj[cur]) {
            if (depth[nxt] != -1) continue;
            depth[nxt] = depth[cur] + 1;
            frontier.push_back(nxt);
        }
    }
    return far;
}

// Bottom-up greedy count of "cuts" with the graph rooted at `root`.
// Vertices are 1-indexed; index 0 doubles as the "no parent" sentinel.
// `weight[0]` is expected to be 0.
//
// Every vertex starts with load = its weight. Vertices are then visited in
// reverse DFS preorder, so each vertex is visited after all of its descendants.
// For each visited vertex:
//   * if its load exceeds `limit`, one cut is counted and its load becomes 0.
//     A merge never lets a load exceed `limit`, so this fires exactly when the
//     vertex's own weight exceeds `limit`;
//   * otherwise, if it has a parent and a positive load, the load is added to
//     the parent when the parent's total stays <= `limit`. If not, one cut is
//     counted and the load is not passed up.
//
// NOTE(review): children are folded into a parent in the order the DFS stack
// produces. The result therefore depends on adjacency-list order. If the task
// is "minimum edge cuts so every component sum <= k", absorbing the smallest
// child loads first is the usual optimal greedy. Confirm against the problem
// statement before relying on this.
long long greedy_cuts(const std::vector<std::vector<int>>& adj,
                      const std::vector<long long>& weight,
                      long long limit, int root) {
    std::vector<int> up(adj.size(), 0);  // DFS parent; 0 = none (the root)
    std::vector<int> preorder;
    preorder.reserve(adj.size());

    // Iterative DFS: a neighbour is skipped only when it is the current
    // vertex's recorded parent.
    std::vector<int> pending{root};
    while (!pending.empty()) {
        const int cur = pending.back();
        pending.pop_back();
        preorder.push_back(cur);
        for (const int nxt : adj[cur]) {
            if (nxt != up[cur]) {
                up[nxt] = cur;
                pending.push_back(nxt);
            }
        }
    }

    std::vector<long long> load(weight);
    long long cuts = 0;
    for (auto it = preorder.rbegin(); it != preorder.rend(); ++it) {
        const int cur = *it;
        if (load[cur] > limit) {
            ++cuts;
            load[cur] = 0;
        }
        const int par = up[cur];
        if (par == 0 || load[cur] <= 0) continue;  // root, or nothing to pass up
        if (load[par] + load[cur] > limit) {
            ++cuts;
        } else {
            load[par] += load[cur];
        }
    }
    return cuts;
}

}  // namespace

// Input: n and k, then weights b[1..n], then n-1 undirected edges "x y"
// (1-indexed). The edges are presumably a tree; nothing here validates that.
// Output: the smaller greedy cut count over two roots. The roots are the two
// ends of a longest BFS path (a diameter when the graph is a tree). With a
// single vertex, the output is 1 if b[1] > k and 0 otherwise.
int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int n;
    long long k;
    if (!(std::cin >> n >> k)) return 0;

    std::vector<long long> b(n + 1);  // b[0] stays 0 (sentinel slot)
    for (int i = 1; i <= n; ++i) std::cin >> b[i];

    std::vector<std::vector<int>> g(n + 1);
    for (int e = 1; e < n; ++e) {
        int x, y;
        std::cin >> x >> y;
        g[x].push_back(y);
        g[y].push_back(x);
    }

    if (n == 1) {
        std::cout << (b[1] > k ? 1 : 0) << "\n";
        return 0;
    }

    const int end_a = farthest_vertex(g, 1);
    const int end_b = farthest_vertex(g, end_a);
    const long long answer =
        std::min(greedy_cuts(g, b, k, end_a), greedy_cuts(g, b, k, end_b));
    std::cout << answer << "\n";
    return 0;
}
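/*
 * Non-code residue: what appears to be a judge's results table pasted after
 * the program. It is kept inside a comment so the translation unit compiles.
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| Fetching results... |
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| Fetching results... |
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| Fetching results... |
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| Fetching results... |
 */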