#include "closing.h"
#include <bits/stdc++.h>
// Every `int` written below this line is really `long long` (64-bit);
// presumably because distance sums and the budget may not fit in 32 bits.
#define int long long
// Shorthands for pairs of (64-bit) ints; max_score stores (neighbor, weight)
// adjacency entries as pii and reads them through .ff / .ss.
#define pii pair<int,int>
#define ff first
#define ss second
// `signed` is the genuine 32-bit int that the macro above hides. max_score's
// signature uses it — NOTE(review): presumably to match the declaration in
// closing.h; confirm against the header.
#define sgn signed
// Large sentinel (4e18, exactly representable and below LLONG_MAX ~ 9.22e18).
// max_score uses it to mark a cost that must never be taken.
const int inf = 4e18;
using namespace std;
// Per-node aggregate for the segment tree: `sum` is a total of values and
// `cnt` a count (seggy uses it for the number of active leaves below a node).
struct node{
    int sum = 0;
    int cnt = 0;

    // Componentwise sum of *this and b; neither operand is modified.
    node operator+(node b)
    {
        node out = *this;
        out += b;
        return out;
    }

    // Componentwise in-place accumulation of b into *this.
    void operator+=(node b)
    {
        sum += b.sum;
        cnt += b.cnt;
    }
};
struct seggy{
int n;
vector<node> t;
seggy(int sz)
{
n = sz;
t.resize(4*n,node());
}
node query(int l,int r,int L,int R,int idx)
{
if(l>R||r<L)return node();
if(l>=L&&r<=R)return t[idx];
int mid = l+r>>1;
return query(l,mid,L,R,2*idx)+query(mid+1,r,L,R,2*idx+1);
}
node query(int l,int r)
{
return query(0,n-1,l,r,1);
}
void update(int l,int r,int idx,int i,int x)
{
if(l>i||r<i)return;
if(l==r)
{
t[idx].sum = max(0ll,x);
t[idx].cnt = (x>=0);
return;
}
int mid = l+r>>1;
update(l,mid,2*idx,i,x);
update(mid+1,r,2*idx+1,i,x);
t[idx] = t[2*idx]+t[2*idx+1];
}
void update(int i,int x)
{
update(0,n-1,1,i,x);
}
int csum(int l,int r,int sum, int idx,int tar,int tot)
{
if(l==r)return sum+(t[idx].cnt+tot<=tar?t[idx].sum:0);
int mid = l+r>>1;
if(t[2*idx].cnt+tot<=tar)
return csum(mid+1,r,sum+t[2*idx].sum,2*idx+1,tar,tot+t[2*idx].cnt);
return csum(l,mid,sum,2*idx,tar,tot);
}
int cidx(int l,int r,int sum,int idx,int tar,int tot)
{
if(l==r)return tot+(t[idx].sum+sum<=tar?t[idx].cnt:0);
int mid = l+r>>1;
if(t[2*idx].sum+sum<=tar)
return cidx(mid+1,r,sum+t[2*idx].sum,2*idx+1,tar,tot+t[2*idx].cnt);
return cidx(l,mid,sum,2*idx,tar,tot);
}
};
// Grader entry point (declared in closing.h, not visible here — NOTE(review):
// the problem semantics below are inferred from this code; confirm against
// the task statement).
//
// Each vertex v gets two costs from its distances to x and y:
//   a = the smaller distance, b = the larger one.
// Paying a counts v once, paying b counts it twice; total paid must be <= l.
//   * Case 1: nobody is counted twice — buy the cheapest a-values greedily.
//   * Case 2: every vertex on the x..y path is paid its a-cost up front and
//     counted once; its remaining option becomes an upgrade costing b - a.
//     Off-path vertices are then promoted one at a time (ordered by b - a,
//     ties by b) into a "counted twice" pool; after each promotion the best
//     number of pool vertices to buy is found by binary search.
//
// Parameters:
//   n        number of vertices, labelled 0..n-1
//   x, y     the two source vertices
//   l        budget (64-bit: `int` is #defined to long long in this file)
//   u, v, w  the n-1 edges u[i]-v[i] with weight w[i]
// Returns the largest total count found over both cases.
sgn max_score(sgn n, sgn x, sgn y, int l,
              std::vector<sgn> u, std::vector<sgn>v, std::vector<sgn>w)
{
    vector<vector<pii>> graph(n);
    for (int i = 0; i < n - 1; i++) {
        graph[u[i]].push_back({v[i], w[i]});
        graph[v[i]].push_back({u[i], w[i]});
    }

    // d[k][0] / d[k][1]: distance of vertex k from x / from y (swapped below
    // into (a, b) = (min, max)). p: parent pointers toward the latest source.
    vector<array<int,2>> d(n, {0, 0});
    vector<int> p(n);

    // Traversal from src filling d[*][idx] and p. It is iterative so that a
    // path-shaped tree cannot exhaust the call stack (the former recursive
    // std::function version had recursion depth up to n).
    auto walk = [&](int src, int idx) {
        p[src] = src;
        vector<int> stk(1, src);
        while (!stk.empty()) {
            const int cur = stk.back();
            stk.pop_back();
            for (const pii& e : graph[cur]) {
                if (e.ff == p[cur]) continue;
                p[e.ff] = cur;
                d[e.ff][idx] = d[cur][idx] + e.ss;
                stk.push_back(e.ff);
            }
        }
    };
    walk(x, 0);
    walk(y, 1);  // runs last, so p ends up pointing toward y
    for (auto& dv : d)
        if (dv[1] < dv[0]) swap(dv[0], dv[1]);

    // pos[k] = index of vertex k after sorting all vertices by key(k).
    // Ties are broken arbitrarily. Tied keys always hold equal values in the
    // segment trees below, so tie order cannot change any query result.
    auto rank_by = [&](const function<int(int)>& key) -> vector<int> {
        vector<int> order(n), pos(n);
        iota(order.begin(), order.end(), 0);
        sort(order.begin(), order.end(),
             [&](int a, int b) { return key(a) < key(b); });
        for (int i = 0; i < n; i++) pos[order[i]] = i;
        return pos;
    };

    // id:  promotion order — by |b - a|, ties by b.
    // idl: leaf position of each vertex in sa (ranked by b).
    // idr: leaf position of each vertex in sb (ranked by a).
    vector<int> id(n), idl, idr;
    iota(id.begin(), id.end(), 0);
    auto reorder = [&]() {
        sort(id.begin(), id.end(), [&](int a, int b) {
            const int ga = abs(d[a][0] - d[a][1]);
            const int gb = abs(d[b][0] - d[b][1]);
            return ga < gb || (ga == gb && d[a][1] < d[b][1]);
        });
        idl = rank_by([&](int k) { return max(d[k][0], d[k][1]); });
        idr = rank_by([&](int k) { return min(d[k][0], d[k][1]); });
    };
    reorder();

    // sa: b-values of promoted ("counted twice") candidates.
    // sb: a-values of vertices that may still be counted once.
    seggy sa(n), sb(n);
    vector<int> pth{x};
    for (int i : id) sb.update(idr[i], d[i][0]);

    // Score when the k cheapest promoted vertices are bought at cost b and
    // the leftover budget buys single counts greedily from sb. If even those
    // k do not fit, returns -|pth| so that adding |pth| afterwards gives 0.
    auto calc = [&](int k) -> int {
        const int val = sa.csum(0, n - 1, 0, 1, k, 0);
        if (l < val) return -static_cast<int>(pth.size());
        return 2 * k + sb.cidx(0, n - 1, 0, 1, l - val, 0);
    };

    // Binary search for the peak of calc over [0, |sa|]. This assumes calc
    // is unimodal in k.
    auto best = [&]() -> int {
        int lo = 0, hi = sa.t[1].cnt;
        while (lo < hi) {
            const int mid = (lo + hi) / 2;
            if (calc(mid + 1) > calc(mid)) lo = mid + 1;
            else hi = mid;
        }
        return calc(lo);
    };

    int ans = best();  // case 1: sa is still empty, so this is calc(0)

    // Case 2: collect the x..y path by following parent pointers toward y.
    while (pth.back() != y) pth.push_back(p[pth.back()]);
    // Pay each path vertex's a-cost now. Its only remaining option is the
    // b - a upgrade, and inf marks it as never promotable into sa.
    for (int i : pth) {
        l -= d[i][0];
        d[i][0] = d[i][1] - d[i][0];
        d[i][1] = inf;
    }
    reorder();
    for (int i : id) sb.update(idr[i], d[i][0]);  // overwrites every leaf

    const int on_path = static_cast<int>(pth.size());
    ans = max(ans, best() + on_path);
    for (int i : id) {
        if (d[i][1] == inf) break;  // path vertices sort last in id
        sa.update(idl[i], d[i][1]);
        sb.update(idr[i], -1);      // negative value deactivates the leaf
        ans = max(ans, best() + on_path);
    }
    return ans;
}