#include "closing.h"
#define pb push_back
#include <bits/stdc++.h>
#include <vector>
using namespace std;
// Capacity of the fixed-size per-vertex arrays below; vertex ids (and the
// positions assigned by dfs0) must stay below this bound.
const int maxn = 2e5 + 10;
// Global problem parameters (city count, the two festival cities, budget).
// dfs0 compares the vertices it visits against x and y.
long long n, x, y, k;
// Tree adjacency list: g[v] holds (neighbour, edge weight) pairs.
vector < pair < int, int > > g[maxn];
// Distance tables lowered by dfsx / dfsy respectively: distance from the start
// vertex of the corresponding DFS, plus the initial offset passed to it.
long long distx[maxn], disty[maxn];
// Visited flags shared by dfsx and dfsy; both skip flagged vertices, so the
// flags need to be cleared before each DFS.
int used[maxn];
/// Visits every vertex reachable from `beg` without passing through a vertex
/// already flagged in used[], flags each visited vertex, and lowers disty[v]
/// to `dep` plus the weighted path length from `beg` to v.
///
/// Uses an explicit stack instead of recursion: the recursive form nests once
/// per edge, which on a path-shaped tree with ~2e5 vertices may overflow the
/// default call stack. In a tree every vertex is reached along exactly one
/// path, so the visiting order does not change the flags or distances.
void dfsy(int beg, long long dep)
{
    std::vector<std::pair<int, long long>> pending;
    pending.emplace_back(beg, dep);
    used[beg] = 1;
    while (!pending.empty())
    {
        const auto [v, d] = pending.back();
        pending.pop_back();
        disty[v] = std::min(disty[v], d);
        for (const auto &[to, w] : g[v])
        {
            if (used[to]) continue;
            used[to] = 1;  // flag on push so a vertex is queued at most once
            pending.emplace_back(to, d + w);
        }
    }
}
/// Visits every vertex reachable from `beg` without passing through a vertex
/// already flagged in used[], flags each visited vertex, and lowers distx[v]
/// to `dep` plus the weighted path length from `beg` to v.
///
/// Uses an explicit stack instead of recursion: the recursive form nests once
/// per edge, which on a path-shaped tree with ~2e5 vertices may overflow the
/// default call stack. In a tree every vertex is reached along exactly one
/// path, so the visiting order does not change the flags or distances.
void dfsx(int beg, long long dep)
{
    std::vector<std::pair<int, long long>> pending;
    pending.emplace_back(beg, dep);
    used[beg] = 1;
    while (!pending.empty())
    {
        const auto [v, d] = pending.back();
        pending.pop_back();
        distx[v] = std::min(distx[v], d);
        for (const auto &[to, w] : g[v])
        {
            if (used[to]) continue;
            used[to] = 1;  // flag on push so a vertex is queued at most once
            pending.emplace_back(to, d + w);
        }
    }
}
// Filled by dfs0: p[v] is the position assigned to vertex v (the start index
// plus its depth); a[pos] is the vertex most recently placed at position pos.
int p[maxn], a[maxn];
// Filled by dfs0: the first and second of the vertices x / y met during the
// walk. dfs0 tests st0 == -1, so both are expected to be reset to -1 first.
int st0, st1;
// NOTE(review): `is` is not referenced anywhere in the visible source.
int is[maxn];
/// Walks the tree from `beg` (whose parent is `par`, -1 for a root) in exactly
/// the pre-order a recursive DFS over g[] would use, and for every vertex v at
/// depth d below `beg`:
///   * sets p[v] = i + d and a[i + d] = v (a later visit at the same depth
///     overwrites a[]; on a path every position is hit exactly once);
///   * if v equals x or y, stores it in st0 when st0 == -1, else in st1.
///
/// Uses an explicit frame stack instead of recursion: the recursive form nests
/// once per edge, which on a path-shaped tree with ~2e5 vertices may overflow
/// the default call stack.
void dfs0(int beg, int par, int i)
{
    struct Frame
    {
        int vertex;
        int parent;
        int pos;
        std::size_t next_edge;  // index of the next neighbour in g[vertex] to try
    };
    std::vector<Frame> frames;

    // Pre-order work done on entering a vertex, then push its frame.
    const auto enter = [&frames](int v, int parent, int pos) {
        if (v == x || v == y)
        {
            if (st0 == -1) st0 = v;
            else st1 = v;
        }
        p[v] = pos;
        a[pos] = v;
        frames.push_back({v, parent, pos, 0});
    };

    enter(beg, par, i);
    while (!frames.empty())
    {
        Frame &top = frames.back();
        if (top.next_edge == g[top.vertex].size())
        {
            frames.pop_back();  // all neighbours done: "return" from this vertex
            continue;
        }
        const int to = g[top.vertex][top.next_edge].first;
        ++top.next_edge;
        if (to == top.parent) continue;
        const int self = top.vertex;
        const int child_pos = top.pos + 1;
        enter(to, self, child_pos);  // may reallocate `frames`; `top` is not used afterwards
    }
}
// Scratch cost buffer; presumably indexed by the 1-based positions stored in
// a[] — confirm against its users.
long long currdist[maxn];
namespace {

/// Adjacency list of a weighted tree: adj[v] holds (neighbour, edge weight).
using ClosingTree = std::vector<std::vector<std::pair<int, int>>>;

/// Weighted distance from `src` to every vertex of the tree.
/// Iterative so a path-shaped tree (depth ~N) cannot overflow the call stack.
std::vector<long long> closing_distances(const ClosingTree &adj, int src)
{
    std::vector<long long> dist(adj.size(), -1);  // -1 marks "not reached yet"
    std::vector<int> todo{src};
    dist[src] = 0;
    while (!todo.empty())
    {
        const int v = todo.back();
        todo.pop_back();
        for (const auto &[to, w] : adj[v])
        {
            if (dist[to] != -1) continue;
            dist[to] = dist[v] + w;
            todo.push_back(to);
        }
    }
    return dist;
}

/// Candidate for solutions where no city is reachable from both festivals.
/// Every value dx[v] / dy[v] is then a unit-score purchase, and buying the
/// cheapest ones first is optimal. An ancestor (towards X or Y) is strictly
/// cheaper than its descendant, so the bought set is always reachable.
/// Buying both values of one city over-pays (its real cost is their max),
/// which keeps the candidate feasible.
int closing_score_independent(const std::vector<long long> &dx,
                              const std::vector<long long> &dy,
                              long long budget)
{
    std::vector<long long> costs(dx);
    costs.insert(costs.end(), dy.begin(), dy.end());
    std::sort(costs.begin(), costs.end());
    int score = 0;
    for (const long long cost : costs)
    {
        if (cost > budget) break;
        budget -= cost;
        ++score;
    }
    return score;
}

/// Candidate for solutions where some city is reachable from both festivals.
///
/// Every city on the X–Y path must then be reachable from at least its nearer
/// festival, at cost lo = min(dx, dy) for 1 point. Any city may instead be
/// reached from both, at cost hi = max(dx, dy) for 2 points. A min-cost
/// assignment of these levels can always be made reachable, so this becomes a
/// knapsack:
///  * path-city upgrades (hi - lo) and "convex" off-path cities
///    (hi - lo >= lo, split into lo and hi - lo) are independent unit items;
///    the cheapest prefix is optimal;
///  * "concave" off-path cities (hi - lo < lo) are pair items. An exchange
///    argument shows at most one of them is bought half-way (for lo). The
///    fully bought ones are the cheapest by hi among the rest.
///
/// @param xy_length weighted length of the X–Y path, i.e. dx[Y].
/// @return best score, or 0 when the mandatory path cost exceeds the budget.
int closing_score_shared(const std::vector<long long> &dx,
                         const std::vector<long long> &dy,
                         long long xy_length, long long budget)
{
    const int count = static_cast<int>(dx.size());
    long long mandatory = 0;
    int path_score = 0;
    std::vector<long long> units;                        // price of one extra point
    std::vector<std::pair<long long, long long>> pairs;  // (hi, lo) of concave cities
    for (int v = 0; v < count; ++v)
    {
        const long long lo = std::min(dx[v], dy[v]);
        const long long hi = std::max(dx[v], dy[v]);
        // On the X–Y path iff dx + dy equals the path length. This relies on
        // every edge weight being >= 1 (assumed from the task constraints).
        if (dx[v] + dy[v] == xy_length)
        {
            mandatory += lo;
            ++path_score;
            units.push_back(hi - lo);
        }
        else if (hi - lo >= lo)
        {
            units.push_back(lo);
            units.push_back(hi - lo);
        }
        else
        {
            pairs.emplace_back(hi, lo);
        }
    }
    if (mandatory > budget) return 0;
    budget -= mandatory;

    std::sort(units.begin(), units.end());
    std::vector<long long> unit_prefix(units.size() + 1, 0);
    for (std::size_t i = 0; i < units.size(); ++i)
        unit_prefix[i + 1] = unit_prefix[i] + units[i];
    // Largest s with unit_prefix[s] <= money; only called with money >= 0.
    const auto units_affordable = [&unit_prefix](long long money) {
        return static_cast<int>(std::upper_bound(unit_prefix.begin(), unit_prefix.end(), money)
                                - unit_prefix.begin()) - 1;
    };

    std::sort(pairs.begin(), pairs.end());  // cheapest full purchase first
    const int m = static_cast<int>(pairs.size());
    const long long kInf = std::numeric_limits<long long>::max();
    std::vector<long long> pair_prefix(m + 1, 0);      // sum of hi over the j cheapest
    std::vector<long long> min_swap_upto(m + 1, kInf);  // min (lo - hi) over indices < j
    std::vector<long long> min_lo_from(m + 1, kInf);    // min lo over indices >= j
    for (int i = 0; i < m; ++i)
    {
        pair_prefix[i + 1] = pair_prefix[i] + pairs[i].first;
        min_swap_upto[i + 1] = std::min(min_swap_upto[i], pairs[i].second - pairs[i].first);
    }
    for (int i = m - 1; i >= 0; --i)
        min_lo_from[i] = std::min(min_lo_from[i + 1], pairs[i].second);

    int best = 0;
    const auto consider = [&](long long cost, int points) {
        if (cost <= budget)
            best = std::max(best, points + units_affordable(budget - cost));
    };
    for (int j = 0; j <= m; ++j)
    {
        consider(pair_prefix[j], 2 * j);  // j cheapest pairs, none half-bought
        if (j < m)
        {
            // j full pairs plus one half-bought pair p:
            //   p >= j -> the full ones are the first j.
            consider(pair_prefix[j] + min_lo_from[j], 2 * j + 1);
            //   p <= j -> the full ones are the first j + 1 without p.
            consider(pair_prefix[j + 1] + min_swap_upto[j + 1], 2 * j + 1);
        }
    }
    return path_score + best;
}

}  // namespace

/// Maximum convenience score for the tree given by edges (U[e], V[e]) with
/// weight W[e] on N cities.
///
/// Each city v gets a closing time c[v] >= 0 with sum(c) <= K. City v counts
/// as reachable from festival city S (S = X or S = Y) when every city u after
/// S on the S→v path satisfies dist(S, u) <= c[u]. S itself always counts.
/// The result is (#reachable from X) + (#reachable from Y).
///
/// The answer is the better of two candidates: one where no city is reachable
/// from both festivals, and one where some city is. Runs in O(N log N) and
/// does not touch the file-level globals.
/// Assumes every W[e] >= 1 and X != Y (presumably guaranteed by the task).
int max_score(int N, int X, int Y, long long K,
              std::vector<int> U, std::vector<int> V, std::vector<int> W)
{
    ClosingTree adj(N);
    for (std::size_t e = 0; e < U.size(); ++e)
    {
        adj[U[e]].emplace_back(V[e], W[e]);
        adj[V[e]].emplace_back(U[e], W[e]);
    }
    const std::vector<long long> dx = closing_distances(adj, X);
    const std::vector<long long> dy = closing_distances(adj, Y);
    return std::max(closing_score_independent(dx, dy, K),
                    closing_score_shared(dx, dy, dx[Y], K));
}