#include "highway.h"
#define fi first
#define se second
#define ll long long
using namespace std;
// Progress notes:
// subtask 2 solved; next target subtask 4, then subtask 6.
// Per-vertex labels written by dfs():
//   dist      - depth below the root the dfs was started from
//   parEdge   - index of the edge to the dfs parent (unset for a dfs root)
//   colVertex - side colour passed to dfs (1 or 2)
std::vector<int> dist;
std::vector<int> parEdge;
std::vector<int> colVertex;
// Per-edge side colour written by dfs(); edges dfs never traverses keep their
// initial value.
std::vector<int> col;
// Adjacency lists: adj[x] holds (neighbour, edge index) pairs.
std::vector<std::vector<std::pair<int, int>>> adj;
/*
 * Label everything reachable from `u` without stepping back to `par`:
 *   colVertex[x] = c       for every reached vertex x
 *   dist[x]      = depth   (u gets depth d, each step adds 1)
 *   parEdge[x]   = index of the edge from x's parent (not written for u)
 *   col[e]       = c       for every traversed edge e
 *
 * Uses an explicit stack rather than recursion, so a long path (depth up to n)
 * cannot overflow the call stack. Only the immediate parent is skipped (there
 * is no visited set), so the part of the graph reachable from u must be a tree.
 */
void dfs(int u, int d, int c, int par = -1){
    struct Frame { int vertex, parent, depth; };
    vector<Frame> todo;
    todo.push_back(Frame{u, par, d});
    while (!todo.empty()){
        const Frame cur = todo.back();
        todo.pop_back();
        colVertex[cur.vertex] = c;
        dist[cur.vertex] = cur.depth;
        for (const pair<int, int>& e : adj[cur.vertex]){
            const int nxt = e.first;
            if (nxt == cur.parent) continue;
            parEdge[nxt] = e.second;
            col[e.second] = c;
            todo.push_back(Frame{nxt, cur.vertex, cur.depth + 1});
        }
    }
}
/*
 * Find the path endpoint lying on side `inSub` of the bridging edge chosen by
 * find_pair (relies on dist / colVertex / parEdge as filled in by dfs).
 *
 * Candidates are the vertices coloured `inSub` whose depth equals `len`. On a
 * tree, exactly one candidate is the endpoint, and its parent edge is the only
 * candidate parent edge lying on the s-t path. Binary search: make the parent
 * edges of cand[l..mid] heavy; the toll rises above the all-light baseline iff
 * the endpoint is among them.
 *
 *   n      number of vertices
 *   u, v   edge endpoint lists (only u.size(), the edge count, is used)
 *   a, b   light / heavy tolls (unused here; kept for the existing signature)
 *   len    depth of the wanted endpoint below this side's dfs root
 *   inSub  side colour to search (1 or 2)
 *   base   toll of the all-light query; if negative it is measured with one
 *          extra ask() call, and only when a search step is actually needed
 * Returns the endpoint vertex.
 */
int sub2(int n, vector<int> u, vector<int> v, int a, int b, int len, int inSub, long long base = -1){
    const int m = (int)u.size();
    vector<int> cand;
    for (int i = 0; i < n; i++){
        if (dist[i] == len && colVertex[i] == inSub) cand.push_back(i);
    }
    // w must be sized m before it is indexed by edge id.
    vector<int> w(m, 0);
    int l = 0, r = (int)cand.size() - 1;
    if (l < r && base < 0) base = ask(w);
    while (l < r){
        const int mid = l + (r - l) / 2;
        w.assign(m, 0);
        for (int i = l; i <= mid; i++) w[parEdge[cand[i]]] = 1;
        // Compare against the full-path baseline, not just this side's length.
        if (ask(w) > base) r = mid;
        else l = mid + 1;
    }
    return cand[r];
}

/*
 * Find the endpoints s, t of the hidden shortest path and report them with
 * answer(s, t).
 *
 * Assumes the graph is a tree: dfs only skips the immediate parent, so the
 * side labelling is only well-defined on acyclic graphs.
 *
 *   n       number of vertices
 *   u, v    edge i joins u[i] and v[i]
 *   ap, bp  toll of an edge with w[i] == 0 / w[i] == 1
 */
void find_pair(int n, vector<int> u, vector<int> v, int ap, int bp) {
    const long long a = ap, b = bp;
    const int m = (int)u.size();
    vector<int> w(m, 0);
    // All edges light: toll = a * (number of edges on the s-t path).
    const long long base = ask(w);
    const long long totLen = base / a;

    col.assign(m, 0);
    colVertex.assign(n, 0);
    dist.assign(n, 0);
    parEdge.assign(n, 0);
    adj.assign(n, vector<pair<int, int>>());

    // Find one path edge. Invariant: edges l..r contain at least one path edge.
    int l = 0, r = m - 1;
    while (l < r){
        const int mid = l + (r - l) / 2;
        w.assign(m, 0);
        for (int i = l; i <= mid; i++) w[i] = 1;
        if (ask(w) > base) r = mid;
        else l = mid + 1;
    }
    // Edge r is on the path.
    const int st1 = u[r], st2 = v[r];
    for (int i = 0; i < m; i++){
        adj[u[i]].push_back({v[i], i});
        adj[v[i]].push_back({u[i], i});
    }
    // Split the tree at edge r: st1's side gets colour 1, st2's side colour 2;
    // edge r itself is never traversed and keeps colour 0.
    dfs(st1, 0, 1, st2);
    dfs(st2, 0, 2, st1);

    // The path is len1 colour-1 edges, edge r, then len2 colour-2 edges, with
    // len1 + len2 + 1 == totLen. Making only colour-2 edges heavy gives
    // toll = base + (b - a) * len2, so one query yields both lengths.
    for (int i = 0; i < m; i++) w[i] = (col[i] == 2) ? 1 : 0;
    const long long len2 = (ask(w) - base) / (b - a);
    const long long len1 = totLen - 1 - len2;

    const int s = sub2(n, u, v, ap, bp, (int)len1, 1, base);
    const int t = sub2(n, u, v, ap, bp, (int)len2, 2, base);
    answer(s, t);
}