#include "closing.h"
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
using pii = array<ll,2>;
// Signed size of any container exposing size().
// Spelled as an explicit template: a `const auto&` parameter is a C++20
// abbreviated function template, while the rest of this file only needs C++17.
template <class Container>
static inline int SZ(const Container &v) { return static_cast<int>(v.size()); }
// Best total saving for every number of "lost" points over items given as
// (gap, cost) pairs. Dropping the first point of an item saves `gap`;
// dropping the second saves `cost - gap`. The caller only routes items with
// gap >= cost/2 here, so each item's first saving is at least its second.
// Every single point can therefore be given up independently: sort all
// savings descending and take prefix sums.
//
// Returns {odd, even}:
//   odd[k]  = best saving after losing 2k+1 points,
//   even[k] = best saving after losing 2k points (even[0] == 0).
static pair<vector<ll>, vector<ll>> doConvex(vector<pii> p){
    vector<ll> steps;
    steps.reserve(2 * p.size());
    for (const pii &item : p) {
        const ll gap = item[0];
        const ll cost = item[1];
        steps.push_back(gap);         // 2pt -> 1pt
        steps.push_back(cost - gap);  // 1pt -> 0pt
    }
    sort(steps.begin(), steps.end(), greater<ll>());

    // prefix[i] = saving obtained from the i largest steps.
    vector<ll> prefix(1, 0);
    prefix.reserve(steps.size() + 1);
    for (ll s : steps) prefix.push_back(prefix.back() + s);

    vector<ll> oddLost, evenLost;
    for (size_t lost = 0; lost < prefix.size(); ++lost) {
        if (lost % 2 == 1) oddLost.push_back(prefix[lost]);
        else evenLost.push_back(prefix[lost]);
    }
    return {oddLost, evenLost};
}
// Best total saving for every number of "lost" points over items given as
// (gap, cost) pairs with gap < cost/2. For these items, dropping the second
// point (saving cost - gap) is worth more than dropping the first (saving gap).
// Items are ranked by cost, largest first.
//   even[k] = sum of the k largest costs (both points of k items dropped).
//   odd[k]  = the better of
//     * the k+1 largest costs, with the item maximising (gap - cost) among
//       them only dropped by one point, or
//     * the k largest costs plus the largest gap found at rank >= k.
// The odd formula assumes at most one item is dropped by exactly one point.
// Returns {odd, even}; odd has |p| entries and even has |p|+1.
static pair<vector<ll>, vector<ll>> doConcave(vector<pii> p){
    sort(p.begin(), p.end(), [](const pii &lhs, const pii &rhs) {
        return lhs[1] > rhs[1];
    });
    const int cnt = SZ(p);

    // bestGapFrom[i] = largest gap among ranks i..cnt-1 (sentinel at cnt).
    vector<ll> bestGapFrom(cnt + 1, (ll)-2e18);
    for (int i = cnt - 1; i >= 0; --i)
        bestGapFrom[i] = max(bestGapFrom[i + 1], p[i][0]);

    vector<ll> oddLost;
    vector<ll> evenLost{0};
    ll costSum = 0;                  // sum of the i+1 largest costs
    ll bestPartial = (ll)-2e18;      // max (gap - cost) over those i+1 items
    for (int i = 0; i < cnt; ++i) {
        const ll gap = p[i][0];
        const ll cost = p[i][1];
        costSum += cost;
        evenLost.push_back(costSum);
        bestPartial = max(bestPartial, gap - cost);
        const ll halfOneOfTop = costSum + bestPartial;
        const ll topPlusLaterGap = costSum - cost + bestGapFrom[i];
        oddLost.push_back(max(halfOneOfTop, topPlusLaterGap));
    }
    return {oddLost, evenLost};
}
// Minimum total cost to collect exactly t points from off-path vertices, for
// every t in [0, 2n]. Each input item (a, b), with a <= b, costs 0 for
// 0 points, a for 1 point and b for 2 points.
//
// Works in "savings" space. Start with every item at 2 points (cost sumDy),
// then give points back: dropping the first point saves gap = b - a, dropping
// the second saves a.
//   * Items with gap >= a have non-increasing per-point savings; they go to
//     doConvex.
//   * The remaining items go to doConcave.
// Both helpers return best savings split by the parity of the number of lost
// points. The four parity combinations are merged with a max-plus convolution.
//
// Returns a vector of size 2n+1 where result[t] = min cost to gain t points.
static vector<ll> solveCoins(vector<pii> inp){
    const int n = SZ(inp);
    ll sumDy = 0;                 // cost of taking both points of every item
    vector<pii> convex, concave;  // both hold (gap, b) pairs
    for (const pii &ab : inp) {
        const ll a = ab[0], b = ab[1];  // caller passes a <= b
        sumDy += b;
        const ll gap = b - a;           // saving of the 2pt -> 1pt downgrade
        if (gap * 2 >= b) convex.push_back({gap, b});
        else concave.push_back({gap, b});
    }
    auto [odd1, even1] = doConcave(concave);
    auto [odd2, even2] = doConvex(convex);

    // Max-plus convolution of two concave sequences, done by merging their
    // increments. An empty operand means "no achievable count of this parity",
    // so the convolution is empty too. Returning the other operand instead
    // would store its values under the wrong lost-point parity.
    auto conv = [](const vector<ll> &a, const vector<ll> &b) -> vector<ll> {
        if (a.empty() || b.empty()) return {};
        vector<ll> dif;
        dif.reserve(a.size() + b.size() - 2);
        for (int i = 1; i < SZ(a); i++) dif.push_back(a[i] - a[i - 1]);
        for (int i = 1; i < SZ(b); i++) dif.push_back(b[i] - b[i - 1]);
        sort(dif.begin(), dif.end(), greater<ll>());
        vector<ll> out;
        out.reserve(dif.size() + 1);
        out.push_back(a[0] + b[0]);
        for (ll x : dif) out.push_back(out.back() + x);
        return out;
    };

    // bestSave[L] = best saving after losing L points in total.
    vector<ll> bestSave(2 * n + 1, 0);
    // comb[i] is a candidate saving for (2*i + offset) lost points.
    auto relax = [&bestSave](const vector<ll> &comb, int offset) {
        for (int i = 0; i < SZ(comb) && 2 * i + offset < SZ(bestSave); i++)
            bestSave[2 * i + offset] = max(bestSave[2 * i + offset], comb[i]);
    };
    relax(conv(odd1, odd2), 2);    // (2k1+1) + (2k2+1) = 2(k1+k2) + 2
    relax(conv(odd1, even2), 1);   // (2k1+1) + 2k2     = 2(k1+k2) + 1
    relax(conv(even1, odd2), 1);   // 2k1 + (2k2+1)     = 2(k1+k2) + 1
    relax(conv(even1, even2), 0);  // 2k1 + 2k2         = 2(k1+k2)

    // cost(gained points) = sumDy - bestSave(lost points); gained = 2n - lost.
    for (auto &x : bestSave) x = sumDy - x;
    reverse(bestSave.begin(), bestSave.end());
    return bestSave;
}
// Scratch state shared by dfs() and max_score(); arrays hold MAXN vertices.
static constexpr int MAXN = 200005;
static vector<vector<pair<int,int>>> g;   // adjacency: g[u] holds (weight, neighbour)
static int parentArr[MAXN];               // parent pointers written by dfs()
static ll distXY[2][MAXN];                // per-root distances written by dfs()
static void dfs(int root, int p, int id){
stack<int> st;
st.push(root);
parentArr[root] = p;
while(!st.empty()){
int x = st.top(); st.pop();
for(auto [w,y]: g[x]){
if(y==parentArr[x]) continue;
parentArr[y]=x;
distXY[id][y] = distXY[id][x] + (ll)w;
st.push(y);
}
}
}
int max_score(int N, int X, int Y, long long K,
std::vector<int> U, std::vector<int> V, std::vector<int> W) {
g.assign(N, {});
for(int i=0;i<N-1;i++){
g[U[i]].push_back({W[i], V[i]});
g[V[i]].push_back({W[i], U[i]});
}
for(int i=0;i<N;i++) distXY[0][i]=distXY[1][i]=0;
// distXY[0] from X, distXY[1] from Y, and parentArr from dfs-root
dfs(Y, -1, 1);
dfs(X, -1, 0);
// build path X->Y using parent from dfs(X)?? we used parentArr from dfs(X),
// but parentArr currently corresponds to dfs(X) run, so chain from Y to X.
vector<int> onPath(N, 0);
vector<int> path;
for(int v=Y; v!=X; v=parentArr[v]){
path.push_back(v);
onPath[v]=1;
}
path.push_back(X);
reverse(path.begin(), path.end());
onPath[X]=1;
// Off-path items as (min(dx,dy), max(dx,dy))
vector<pii> coins;
coins.reserve(N);
for(int i=0;i<N;i++){
if(onPath[i]) continue;
ll dx = distXY[0][i], dy = distXY[1][i];
if(dx > dy) swap(dx, dy);
coins.push_back({dx, dy});
}
// s1[t] = min cost to gain t points from off-path nodes
vector<ll> s1 = solveCoins(coins);
// s2[j-1] = min cost to make exactly j nodes on path type-2
// (all path nodes are at least type-1; type-2 segment is contiguous on path)
vector<ll> s2(SZ(path), (ll)2e18);
{
int m = SZ(path);
vector<ll> saveLeft(m), saveRight(m);
ll tot = 0;
for(int i=0;i<m;i++){
int v = path[i];
ll dx = distXY[0][v], dy = distXY[1][v];
ll mx = max(dx, dy);
tot += mx; // baseline: all are type-2
saveLeft[i] = mx - dx; // saving if downgraded to type-1 via X
saveRight[i] = mx - dy; // saving if downgraded to type-1 via Y
}
int l=0, r=m;
while(l<r){
int segLen = r-l; // type-2 segment length
s2[segLen-1] = min(s2[segLen-1], tot);
// pick next downgrade on the side that saves more (greedy, but constrained to prefixes/suffixes)
if(saveLeft[l] > saveRight[r-1]){
tot -= saveLeft[l++];
}else{
tot -= saveRight[--r];
}
}
}
int ans = 0;
// Combine off-path points t and path type-2 count j (>=1 if we use s2)
int j = SZ(s2);
for(int t=0; t<SZ(s1); t++){
while(j>0 && s1[t] + s2[j-1] > K) j--;
if(j>0){
ans = max(ans, t + j + SZ(path)); // |path| baseline + j extra + t from off-path
}
}
// Also consider solutions that don't force "all path nodes type-1" structure:
// picking individually cheapest distances to X or Y (simple upper-bound baseline)
{
ll cursum = 0;
vector<ll> allCosts;
allCosts.reserve(2*N);
for(int i=0;i<N;i++){
allCosts.push_back(distXY[0][i]);
allCosts.push_back(distXY[1][i]);
}
sort(allCosts.begin(), allCosts.end());
for(int i=0;i<SZ(allCosts);i++){
cursum += allCosts[i];
if(cursum > K) break;
ans = max(ans, i+1);
}
}
return ans;
}