#include "factories.h"
#include <bits/stdc++.h>
// Tree adjacency list: adj[u] holds (neighbour, edge weight) pairs. The vector
// starts empty, so it has to be sized to the vertex count before it is indexed.
std::vector<std::vector<std::pair<int,int>>> adj;
// Filled by dfs(): per subtree, the number of non-root vertices of degree one
// (assuming the array starts zeroed and dfs runs once).
int treeSize[500010];
// Filled by computeSize(): subtree sizes inside the current live component.
int size[500010];
// dead[v] becomes true once v has been picked as a centroid by build().
bool dead[500010];
// Filled by dfs(): distance from the dfs start vertex. long long because a sum of
// up to N-1 int edge weights can exceed INT_MAX (an int would overflow -> UB).
long long rootDist[500010];
// Parent of each centroid in the centroid tree (-1 for the top-level centroid).
int centroidParent[500010];
// connectedCentroid[v]: every centroid whose component contained v, pushed in
// build() order, i.e. from the top-level centroid downward.
std::vector<int> connectedCentroid[500010];
// distTable[v][j]: tree distance from v to connectedCentroid[v][j].
std::vector<long long> distTable[500010];
// Rooted DFS from `curr` (call with par = -1 for the root).
// - rootDist[child] = rootDist[curr] + edge weight, so distances are relative to
//   whatever rootDist holds for the starting vertex.
// - treeSize[v] accumulates the number of non-root degree-one vertices (leaves)
//   in v's subtree; it is added to, never reset, so it assumes zeroed storage.
// NOTE(review): this function appears unused in this file.
void dfs(int par,int curr){
    const bool isLeaf = (par != -1) && (adj[curr].size() == 1);
    if(isLeaf){
        treeSize[curr] = 1;
    }
    for(const auto& edge : adj[curr]){
        const int child = edge.first;
        if(child == par){
            continue;
        }
        rootDist[child] = rootDist[curr] + edge.second;
        dfs(curr, child);
        treeSize[curr] += treeSize[child];
    }
}
// Computes the size of curr's subtree within the live (non-dead) component,
// rooted so that `par` is curr's parent. Stores the result in size[curr] and
// returns it. Recursive: depth can reach the component size.
int computeSize(int par,int curr){
    int total = 1;
    for(const auto& edge : adj[curr]){
        const int next = edge.first;
        if(next != par && !dead[next]){
            total += computeSize(curr, next);
        }
    }
    size[curr] = total;
    return total;
}
// Walks down from `curr` toward any live child whose subtree holds more than
// total/2 vertices (first such child in adjacency order). Returns the first
// vertex with no such child, i.e. a centroid. Relies on size[] having been
// filled by computeSize() from the same root.
int findCentroid(int par,int curr,int total){
    int from = par;
    int at = curr;
    for(;;){
        int heavy = -1;
        for(const auto& edge : adj[at]){
            const int nb = edge.first;
            if(nb == from || dead[nb]){
                continue;
            }
            if(size[nb] > total / 2){
                heavy = nb;
                break;
            }
        }
        if(heavy == -1){
            return at;
        }
        from = at;
        at = heavy;
    }
}
// Records, for every live vertex reachable from `curr` without crossing a dead
// vertex, that it belongs to centroid `cen`'s component and its distance to it.
// Appends `cen` to connectedCentroid[v] and the distance to distTable[v].
void dfs_dist(int par,int curr,int cen,long long dist){
    connectedCentroid[curr].push_back(cen);
    distTable[curr].push_back(dist);
    for(const auto& edge : adj[curr]){
        const int next = edge.first;
        if(next != par && !dead[next]){
            const long long extended = dist + edge.second;
            dfs_dist(curr, next, cen, extended);
        }
    }
}
void build(int par,int curr){
int total=computeSize(-1,curr);
int cen=findCentroid(-1,curr,total);
centroidParent[cen]=par;
dead[cen]=true;
dfs_dist(-1,cen,cen,0);
for(auto [to,weight]:adj[cen]){
if(dead[to])continue;
build(cen,to);
}
}
void Init(int N, int A[], int B[], int D[]) {
for(int i=0;i<N-1;i++){
adj[A[i]].push_back({B[i],D[i]});
adj[B[i]].push_back({A[i],D[i]});
}
build(-1,0);
}
// Sentinel "no distance yet" value (10^18); assumed larger than any real path.
constexpr long long INF = 1000000000000000000LL;
// Query counter, used as a timestamp so shortest[] never needs a full reset.
int q = 0;
// shortest[c]: during the current query, min distance from any source to centroid c.
long long shortest[500100];
// last[c]: the value of q when shortest[c] was last (re)initialised.
int last[500100];
// Returns the minimum tree distance between any vertex in X[0..S-1] and any
// vertex in Y[0..T-1], using the centroid tables built by Init().
// For every centroid c that is an ancestor (in the centroid tree) of some X
// vertex, shortest[c] keeps the nearest X distance; each Y vertex then combines
// it with its own distance to every shared centroid.
long long Query(int S, int X[], int T, int Y[]) {
    // New timestamp: any centroid whose last[] != q holds a stale shortest[]
    // value from an earlier query and is treated as unreached.
    q++;
    for(int i=0;i<S;i++){
        const std::vector<int>& cens = connectedCentroid[X[i]];
        const std::vector<long long>& dists = distTable[X[i]];
        for(std::size_t j=0;j<cens.size();j++){
            const int c = cens[j];
            if(last[c]!=q){
                shortest[c]=INF;
                last[c]=q;
            }
            shortest[c]=std::min(shortest[c],dists[j]);
        }
    }
    long long best = INF;
    for(int i=0;i<T;i++){
        const std::vector<int>& cens = connectedCentroid[Y[i]];
        const std::vector<long long>& dists = distTable[Y[i]];
        for(std::size_t j=0;j<cens.size();j++){
            const int c = cens[j];
            // Skip centroids no X vertex reached in this query.
            if(last[c]!=q) continue;
            best=std::min(shortest[c]+dists[j],best);
        }
    }
    return best;
}
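/*
Sample input (layout: "N Q", then N-1 lines "A B D", then per query a line "S T",
the S vertices of X and the T vertices of Y). Expected answers: 12, 3, 11.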
7 3
0 1 4
1 2 4
2 3 5
2 4 6
4 5 5
1 6 3
2 2
0 6
3 4
3 2
0 1 3
4 6
1 1
2
5
*/