#include "highway.h"
#include <bits/stdc++.h>
#define ll long long
using namespace std;
void find_pair(int n, std::vector<int> u, std::vector<int> v, int A, int B) {
ll a = A;
ll b = B;
ll m = u.size();
if(m!=n-1) assert(0);
vector<vector<array<ll, 2>>> adj(n);
for(ll i = 0; i<m; i++){
adj[u[i]].push_back({v[i], i});
adj[v[i]].push_back({u[i], i});
}
vector<ll> h(n);
ll h_max = 0;
vector<ll> par(n, -1);
auto dfs = [&](auto&& dfs, ll v, ll p)->void{
par[v] = p;
h_max = max(h[v], h_max);
for(auto [u, i]: adj[v]){
if(u==p) continue;
h[u] = h[v]+1;
dfs(dfs, u, v);
}
return;
};
h[0] = 0;
dfs(dfs, 0, -1);
vector<vector<ll>> layers(h_max+1);
for(ll i = 0; i<n; i++) layers[h[i]].push_back(i);
vector<int> w(m, 0);
ll d = ask(w)/a;
ll l = 0; ll r = h_max;
while(r-l>1){
ll mid = (l+r)/2;
for(ll i = 0; i<m; i++) w[i] = 0;
for(ll i = h_max-1; i>=mid; i--){
for(ll j: layers[i]){
for(auto [u, idx]: adj[j]){
if(u == par[j]) continue;
w[idx] = 1;
}
}
}
ll res = ask(w);
if(res > d*a) l = mid;
else r = mid;
}
ll height_lowest = l+1;
l = -1; r = layers[height_lowest].size()-1;
while(r-l>1){
ll mid = (l+r+1)/2;
for(ll i = 0; i<m; i++) w[i] = 0;
for(ll i = 0; i<=mid; i++){
ll j = layers[height_lowest][i];
for(auto [u, idx]: adj[j]){
if(u == par[j]) w[idx] = 1;
}
}
ll res = ask(w);
if(res > d*a) r = mid;
else l = mid;
}
ll s = layers[height_lowest][r];
//swap(layers[height_lowest][r], layers[height_lowest][0]);
vector<ll> leaves;
leaves.push_back(s);
for(ll i = 0; i<n; i++){
if(i != s && i != 0 && (h[i] == height_lowest || (adj[i].size() == 1 && h[i] < height_lowest))) leaves.push_back(i);
}
vector<ll> col(m, 1e18);
vector<ll> col_nd(n, 1e18);
col_nd[0] =0;
for(ll i = 0; i<leaves.size(); i++){
ll nd = leaves[i];
while(nd!=-1){
if(col_nd[nd] != 1e18) break;
col_nd[nd] = i;
nd = par[nd];
}
}
for(ll v = 0; v<n; v++){
for(auto [u, idx]: adj[v]){
if(u==par[v])continue;
col[idx] = col_nd[u];
}
}
l = -1; r = leaves.size()-1;
while(r-l>1){
ll mid = (l+r+1)/2;
for(ll i =0; i<m; i++){
if(col[i] > mid) w[i] =0;
else w[i] = 1;
}
ll res = ask(w);
if(res == d*b) r = mid;
else l = mid;
}
if(r == 0){
ll curr_d = 0;
ll nd = s;
while(curr_d < d){
nd = par[nd];
curr_d++;
}
answer(s, nd);
}
else{
ll lca = leaves[r];
while(col_nd[lca] != 0) lca = par[lca];
ll t = leaves[r];
ll curr_d = height_lowest - h[lca]+h[leaves[r]]-h[lca];
while(curr_d > d){
t = par[t];
curr_d--;
}
answer(s, t);
}
}
/*
 * Judge-results residue pasted from the submission page (results had not loaded yet):
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 */