#include <vector>
#include <algorithm>
#include <cmath>
#include <queue>
using namespace std;
/// Sentinel cost (2 * 10^18) used to mark unreachable knapsack states.
constexpr long long INF = 2'000'000'000'000'000'000LL;
/// Weighted tree edge leading to vertex `to`.
struct Edge {
    int to;       ///< endpoint the edge leads to
    long long w;  ///< edge weight
    bool on_P;    ///< set when both endpoints of the edge lie on the X-Y path
};
namespace {

/// Cost marking a knapsack state that cannot be achieved.
constexpr long long kUnreachableCost = 2'000'000'000'000'000'000LL;

/// Bit flags of a DP state: which of the two sources reaches the vertex.
constexpr int kFromX = 1;
constexpr int kFromY = 2;

/// Up to this many vertices the O(N^2) exact tree knapsack is used.
constexpr int kExactDpLimit = 3000;

/// adjacency[u] holds (neighbour, edge weight) for every edge at u.
using Adjacency = std::vector<std::vector<std::pair<int, long long>>>;

/// Distance and parent of every vertex once the tree is rooted somewhere.
struct RootedTree {
    std::vector<long long> dist;  ///< -1 for vertices not connected to the root
    std::vector<int> parent;      ///< -1 for the root (and unconnected vertices)
};

/// Roots the tree at `root` with one linear traversal. Paths in a tree are
/// unique, so the first visit of a vertex already yields its distance.
RootedTree root_tree(const Adjacency& adjacency, int root) {
    RootedTree rooted{std::vector<long long>(adjacency.size(), -1),
                      std::vector<int>(adjacency.size(), -1)};
    std::vector<int> pending{root};
    rooted.dist[root] = 0;
    while (!pending.empty()) {
        const int u = pending.back();
        pending.pop_back();
        for (const auto& edge : adjacency[u]) {
            const int v = edge.first;
            if (rooted.dist[v] != -1) continue;
            rooted.dist[v] = rooted.dist[u] + edge.second;
            rooted.parent[v] = u;
            pending.push_back(v);
        }
    }
    return rooted;
}

/// Exact tree knapsack, O(N^2) time, recursion depth up to N.
///
/// table[state][count] is the minimum total cost inside a subtree so that the
/// subtree root is in `state` (bitmask of kFromX / kFromY) and `count` points
/// are collected. The tree is rooted at X (`root`).
long long exact_dp_score(const Adjacency& adjacency, int root, long long K,
                         const std::vector<long long>& dX,
                         const std::vector<long long>& dY,
                         const std::vector<char>& on_path) {
    using Table = std::vector<std::vector<long long>>;

    // Decides whether a parent/child state pair can coexist on one edge.
    auto compatible = [](int parent_state, int child_state, bool path_edge) -> bool {
        // Reaching the child from X requires reaching its parent from X.
        if ((child_state & kFromX) != 0 && (parent_state & kFromX) == 0) return false;
        if (path_edge) {
            // On the X-Y path the child is the one nearer to Y, so the
            // dependency for Y is reversed: parent from Y => child from Y.
            return (parent_state & kFromY) == 0 || (child_state & kFromY) != 0;
        }
        // Off the path Y enters through the parent as well.
        return (child_state & kFromY) == 0 || (parent_state & kFromY) != 0;
    };

    auto solve = [&](auto& self, int u, int parent) -> Table {
        Table dp(4, std::vector<long long>(3, kUnreachableCost));
        dp[0][0] = 0;
        dp[kFromX][1] = dX[u];
        dp[kFromY][1] = dY[u];
        dp[kFromX | kFromY][2] = std::max(dX[u], dY[u]);
        int size_u = 1;

        for (const auto& edge : adjacency[u]) {
            const int v = edge.first;
            if (v == parent) continue;
            const bool path_edge = on_path[u] && on_path[v];
            const Table child = self(self, v, u);
            const int size_v = (static_cast<int>(child[0].size()) - 1) / 2;

            Table merged(4, std::vector<long long>(2 * (size_u + size_v) + 1,
                                                   kUnreachableCost));
            for (int su = 0; su < 4; ++su) {
                for (int sv = 0; sv < 4; ++sv) {
                    if (!compatible(su, sv, path_edge)) continue;
                    for (int cu = 0; cu <= 2 * size_u; ++cu) {
                        if (dp[su][cu] == kUnreachableCost) continue;
                        for (int cv = 0; cv <= 2 * size_v; ++cv) {
                            if (child[sv][cv] == kUnreachableCost) continue;
                            long long& slot = merged[su][cu + cv];
                            slot = std::min(slot, dp[su][cu] + child[sv][cv]);
                        }
                    }
                }
            }
            size_u += size_v;
            dp.swap(merged);
        }
        return dp;
    };

    const Table root_table = solve(solve, root, -1);
    long long best = 0;
    for (const auto& row : root_table) {
        for (std::size_t count = 0; count < row.size(); ++count) {
            if (row[count] <= K) best = std::max(best, static_cast<long long>(count));
        }
    }
    return best;
}

/// O(N log N) greedy for large trees.
///
/// Every vertex i has a cost one_cost = min(dX, dY) to be reached by its
/// nearer source and both_cost = max(dX, dY) to be reached by both.
///   Scenario A - no vertex is reached by both sources: take the cheapest
///     one_cost values. Every vertex between a source and a chosen vertex is
///     strictly cheaper (edge weights are assumed positive), so the cheapest
///     prefix is always a valid assignment.
///   Scenario B - some vertex is reached by both sources: then every vertex
///     of the X-Y path must be reached by at least one source, so their
///     one_cost is paid up front. What remains is a knapsack with 1-point
///     and 2-point items, solved by enumerating the number of 2-point
///     bundles taken.
long long greedy_score(long long K, const std::vector<long long>& dX,
                       const std::vector<long long>& dY,
                       const std::vector<char>& on_path) {
    const std::size_t n = dX.size();
    std::vector<long long> one_cost(n), both_cost(n);
    for (std::size_t i = 0; i < n; ++i) {
        one_cost[i] = std::min(dX[i], dY[i]);
        both_cost[i] = std::max(dX[i], dY[i]);
    }

    // ---- Scenario A ----
    std::vector<long long> cheapest = one_cost;
    std::sort(cheapest.begin(), cheapest.end());
    long long best = 0;
    long long remaining = K;
    for (const long long cost : cheapest) {
        if (cost > remaining) break;
        remaining -= cost;
        ++best;
    }

    // ---- Scenario B ----
    long long budget = K;
    long long forced = 0;
    for (std::size_t i = 0; i < n; ++i) {
        if (on_path[i]) {
            budget -= one_cost[i];
            ++forced;
        }
    }
    if (budget < 0) return best;  // the path itself is unaffordable

    // A "bundle" is a vertex whose second point is cheaper than its first
    // (concave); it is worth taking whole, and an optimal solution holds at
    // most one bundle at half level.
    struct Bundle {
        long long both;
        long long one;
    };
    std::vector<long long> singles;
    std::vector<Bundle> bundles;
    for (std::size_t i = 0; i < n; ++i) {
        const long long upgrade = both_cost[i] - one_cost[i];
        if (on_path[i]) {
            singles.push_back(upgrade);  // first point already paid for
        } else if (one_cost[i] <= upgrade) {
            singles.push_back(one_cost[i]);  // convex: two independent points
            singles.push_back(upgrade);
        } else {
            bundles.push_back({both_cost[i], one_cost[i]});
        }
    }
    std::sort(singles.begin(), singles.end());
    std::sort(bundles.begin(), bundles.end(), [](const Bundle& a, const Bundle& b) {
        return a.both != b.both ? a.both < b.both : a.one < b.one;
    });

    std::vector<long long> single_prefix(singles.size() + 1, 0);
    for (std::size_t i = 0; i < singles.size(); ++i) {
        single_prefix[i + 1] = single_prefix[i] + singles[i];
    }
    // Largest number of cheapest singles affordable with `money` (money >= 0).
    auto max_singles = [&](long long money) -> long long {
        return static_cast<long long>(
                   std::upper_bound(single_prefix.begin(), single_prefix.end(), money) -
                   single_prefix.begin()) -
               1;
    };

    const std::size_t t = bundles.size();
    std::vector<long long> bundle_prefix(t + 1, 0);  // cost of the i cheapest bundles
    std::vector<long long> max_upgrade(t + 1, 0);    // largest (both - one) among them
    std::vector<long long> min_one_after(t + 1, kUnreachableCost);  // cheapest `one` in [i, t)
    for (std::size_t i = 0; i < t; ++i) {
        bundle_prefix[i + 1] = bundle_prefix[i] + bundles[i].both;
        max_upgrade[i + 1] = std::max(max_upgrade[i], bundles[i].both - bundles[i].one);
    }
    for (std::size_t i = t; i-- > 0;) {
        min_one_after[i] = std::min(min_one_after[i + 1], bundles[i].one);
    }

    long long extra = 0;
    for (std::size_t i = 0; i <= t; ++i) {
        const long long doubled = 2 * static_cast<long long>(i);
        if (bundle_prefix[i] <= budget) {
            const long long rest = budget - bundle_prefix[i];
            // The i cheapest bundles in full, the rest spent on singles.
            extra = std::max(extra, doubled + max_singles(rest));
            // ... plus one further bundle taken at half level.
            if (i < t && min_one_after[i] <= rest) {
                extra = std::max(extra, doubled + 1 + max_singles(rest - min_one_after[i]));
            }
        }
        if (i >= 1) {
            // The i cheapest bundles with the most expensive upgrade dropped.
            const long long cost = bundle_prefix[i] - max_upgrade[i];
            if (cost <= budget) {
                extra = std::max(extra, doubled - 1 + max_singles(budget - cost));
            }
        }
    }
    return std::max(best, forced + extra);
}

}  // namespace

/// Maximum number of (source, vertex) reachability points obtainable on a
/// weighted tree when the per-vertex values assigned may sum to at most K.
///
/// A vertex is reached from a source when every vertex after the source on
/// the tree path to it has an assigned value at least as large as its
/// distance from that source; a vertex reached from both X and Y counts twice.
///
/// @param N number of vertices, labelled 0..N-1
/// @param X first source vertex
/// @param Y second source vertex
/// @param K total budget for the assigned values
/// @param U,V endpoints of the N-1 edges
/// @param W   weights of the N-1 edges
/// @return the best achievable score
///
/// Trees with at most kExactDpLimit vertices are solved with an exact O(N^2)
/// tree knapsack; larger ones with an O(N log N) greedy.
long long max_score(int N, int X, int Y, long long K, std::vector<int> U, std::vector<int> V, std::vector<int> W) {
    Adjacency adjacency(N);
    for (int i = 0; i < N - 1; ++i) {
        adjacency[U[i]].push_back({V[i], W[i]});
        adjacency[V[i]].push_back({U[i], W[i]});
    }

    const RootedTree from_x = root_tree(adjacency, X);
    const RootedTree from_y = root_tree(adjacency, Y);

    // Walk the parent pointers of the X-rooted tree from Y back to X.
    std::vector<char> on_path(N, 0);
    for (int v = Y; v != -1; v = from_x.parent[v]) on_path[v] = 1;

    if (N <= kExactDpLimit) {
        return exact_dp_score(adjacency, X, K, from_x.dist, from_y.dist, on_path);
    }
    return greedy_score(K, from_x.dist, from_y.dist, on_path);
}