#include "highway.h"
// Widen every later plain `int` in this file to 64-bit `long long` (presumably so
// the costs returned by ask() and products like cost * length cannot overflow).
// Places that must keep the grader's 32-bit interface spell the type `signed`.
#define int long long
// Shorthands for std::pair members.
#define fi first
#define se second
// `ll` expands to `int`, which the macro above then rescans into `long long`.
#define ll int
using namespace std;
// Progress notes: subtasks 2 and 4 are solved; subtask 6 is still to do.
// The approach below assumes the road network is a tree (M = N - 1).
// Per vertex: depth below the root of the side (1 or 2) it was reached from.
std::vector<int> dist;
// Per vertex: index of the edge leading to its DFS parent.
std::vector<int> parEdge;
// Per edge: side (1 or 2) whose DFS traversed it; stays at the caller's fill value
// (0 in find_pair) for edges never traversed, e.g. the split edge.
std::vector<int> col;
// Per vertex: side (1 or 2) it belongs to; stays at the caller's fill value if unreached.
std::vector<int> colVertex;
// Adjacency list: adj[x] holds (neighbour, edge index) pairs.
std::vector<std::vector<std::pair<int, int>>> adj;

// Depth-first walk over the tree stored in `adj`, starting at `u` and never
// stepping back into `par` (so a caller can root one side of a removed edge).
// For every vertex x reached: colVertex[x] = c and dist[x] = its depth, where `u`
// has depth `d`. For every edge (x -> child) traversed: parEdge[child] = edge
// index and col[edge index] = c.
// Assumes `adj` describes a tree (acyclic): only the immediate parent is skipped,
// so a cycle would never terminate.
// An explicit stack is used instead of recursion so that long path-shaped trees
// cannot overflow the call stack.
void dfs(int u, int d, int c, int par = -1){
    struct Frame { int vertex, parent, depth; };
    std::vector<Frame> todo;
    todo.push_back({u, par, d});
    while (!todo.empty()){
        // Copy, not reference: push_back below may reallocate `todo`.
        const Frame f = todo.back();
        todo.pop_back();
        colVertex[f.vertex] = c;
        dist[f.vertex] = f.depth;
        for (const std::pair<int, int>& pa : adj[f.vertex]){
            const int next = pa.first;
            if (next == f.parent) continue;
            parEdge[next] = pa.second;
            col[pa.second] = c;
            todo.push_back({next, f.vertex, f.depth + 1});
        }
    }
}
int sub2(int n, int a, int b, int len, int inSub, int base){
int m = n-1;
vector<signed> cand, w(m, 0);
for (int i = 0; i < n; i++){
if (dist[i] == len && colVertex[i] == inSub){
cand.push_back(i);
}
}
int l = 0, r = (int)cand.size() - 1;
while (l < r){
int mid = (l + r) / 2;
fill(w.begin(), w.end(), 0);
for (int i = l; i <= mid; i++){
w[parEdge[cand[i]]] = 1;
}
int chck = ask(w);
if (chck != base) r = mid;
else l = mid + 1;
}
return cand[r];
}
// Finds the hidden endpoints S and T on a tree (M = N - 1) and reports them via
// answer(). Edge i joins u[i] and v[i]; `ap`/`bp` are the light/heavy edge costs.
// Outline:
//   1. An all-light query gives base = a * (path length).
//   2. Binary search over edge indices finds one edge (st1, st2) on the path.
//   3. Removing that edge splits the tree into side 1 (st1) and side 2 (st2).
//      One query with side-2 edges heavy gives how many path edges lie on each side.
//   4. sub2 binary-searches each side for the endpoint at that depth.
// Requires a != b (the length formula divides by b - a).
void find_pair(signed n, vector<signed> u, vector<signed> v, signed ap, signed bp) {
    const int a = ap, b = bp;
    const int m = u.size();
    std::vector<signed> w(m, 0);
    // With every edge light, the cost is a times the number of path edges.
    const int base = ask(w);
    const int totLen = base / a;
    col.assign(m, 0);
    colVertex.assign(n, 0);
    dist.assign(n, 0);
    parEdge.assign(n, 0);
    adj.assign(n, std::vector<std::pair<int, int>>());

    // Step 2. Invariant: edges [l, r] contain at least one path edge. On a tree the
    // path is unique, so making edges [l, mid] heavy raises the cost iff one of
    // them is on it.
    int l = 0, r = m - 1;
    while (l < r){
        const int mid = (l + r) / 2;
        std::fill(w.begin(), w.end(), 0);
        std::fill(w.begin() + l, w.begin() + mid + 1, 1);
        if (ask(w) != base) r = mid;
        else l = mid + 1;
    }

    // Step 3: edge r is on the path; root one side at each of its endpoints.
    const int st1 = u[r], st2 = v[r];
    for (int i = 0; i < m; i++){
        adj[u[i]].push_back({v[i], i});
        adj[v[i]].push_back({u[i], i});
    }
    dfs(st1, 0, 1, st2);
    dfs(st2, 0, 2, st1);
    // Edges are now coloured 1 or 2 by side; the split edge r keeps colour 0.

    // Make exactly the side-2 edges heavy. The len2 path edges on side 2 then each
    // cost b instead of a: cost = base + (b - a) * len2. The split edge accounts for
    // the remaining 1 of totLen. This single query replaces the two-query
    // sum/difference system.
    for (int i = 0; i < m; i++) w[i] = (col[i] == 2) ? 1 : 0;
    const int side2Cost = ask(w);
    const int len2 = (side2Cost - base) / (b - a);
    const int len1 = totLen - 1 - len2;

    // Step 4: locate the endpoint at the computed depth on each side.
    const int s = sub2(n, a, b, len1, 1, base);
    const int t = sub2(n, a, b, len2, 2, base);
    answer(s, t);
}