#include "stations.h"
#include <vector>
#include <algorithm>
using namespace std;
// Shorthands used throughout this file: vi holds a label vector or a list of
// neighbour labels / edge endpoints; vvi holds adjacency lists.
using vi = vector<int>;
using vvi = vector<vi>;
// Next label to hand out during a labelling pass. Shared by dfs() and label();
// label() resets it to 0 at the start of every call.
int current_given = 0;

// Labels the subtree rooted at `node` (entered from `par`) with consecutive
// integers taken from current_given.
//   - Even depth: the node is labelled BEFORE its subtree (pre-order), so its
//     label is smaller than every other label in its subtree.
//   - Odd depth: the node is labelled AFTER its subtree (post-order), so its
//     label is larger than every other label in its subtree.
// Consequently an even-depth station's neighbours all carry larger labels and
// an odd-depth station's neighbours all carry smaller ones; the routing code
// (find_next_station / solveEven / solveOdd) relies on this.
// Recursion depth equals the depth of the tree.
void dfs(int par, int node, int depth, vi &label, vvi &adj) {
    const bool pre_order = depth % 2 == 0;
    if (pre_order) label[node] = current_given++;
    for (int u : adj[node]) {
        if (u != par) dfs(node, u, depth + 1, label, adj);
    }
    if (!pre_order) label[node] = current_given++;
}

// Assigns each of the n stations a distinct label in [0, n-1].
//   n    - number of stations (vertices 0..n-1).
//   k    - largest label allowed; not read here because labels never exceed
//          n-1 (NOTE(review): assumes the caller guarantees k >= n-1).
//   u, v - the n-1 tree edges; edge i joins u[i] and v[i].
// Returns a vector whose i-th entry is the label of station i. The tree is
// rooted at station 0, which therefore always receives label 0.
vi label(int n, int k, vi u, vi v) {
    // Reset the shared counter: without this, a second call in the same
    // process would continue numbering where the previous call stopped. That
    // yields labels above n-1 (possibly above k), and the root would not get
    // label 0, which solveEven uses to recognise the root.
    current_given = 0;

    vvi adj(n);
    for (int i = 0; i + 1 < n; ++i) {
        adj[u[i]].push_back(v[i]);
        adj[v[i]].push_back(u[i]);
    }

    vi labels(n);
    dfs(-1, 0, 0, labels, adj);  // -1: the root has no parent
    return labels;
}
// Next hop from an even-depth (pre-order labelled) station s. Its neighbour
// labels c are sorted ascending and all greater than s; find_next_station
// only calls this with at least two neighbours.
//   - Root (s == 0): every neighbour is a child.
//   - Otherwise: c.back() is the parent and c[0..size-2] are the children.
// Child c[i] covers the labels (c[i-1], c[i]] (taking c[-1] as s), so the
// answer is the first neighbour whose label is >= t, clamped to the last one.
int solveEven(int s, int t, vi c) {
    const int last = static_cast<int>(c.size()) - 1;

    // Non-root and t outside our subtree: forward to the parent.
    if (s != 0 && (t > c[last - 1] || t < s)) return c[last];

    const int first_not_below =
        static_cast<int>(lower_bound(c.begin(), c.end(), t) - c.begin());
    return c[min(first_not_below, last)];
}
// Next hop from an odd-depth (post-order labelled) station s. Its neighbour
// labels c are sorted ascending and all smaller than s; find_next_station
// only calls this with at least two neighbours.
// c[0] is the parent; c[1..] are the children, each labelled before its own
// subtree. Our subtree therefore occupies [c[1], s]. Inside it, t belongs to
// the last child whose label does not exceed t.
int solveOdd(int s, int t, vi c) {
    const bool in_subtree = c[1] <= t && t <= s;
    if (!in_subtree) return c[0];

    // upper_bound lands one past the last child with label <= t; since
    // c[1] <= t it is at least c.begin() + 2, so stepping back is safe.
    return *(upper_bound(c.begin() + 1, c.end(), t) - 1);
}
// Routing step: station s holds a packet destined for station t, and c lists
// the labels of s's neighbours. Returns the neighbour label to forward to.
int find_next_station(int s, int t, vi c) {
    // Destination is adjacent: hand the packet over directly.
    if (find(c.begin(), c.end(), t) != c.end()) return t;
    // Single neighbour: it is the only way out.
    if (c.size() == 1) return c[0];

    // From here on t is not adjacent and s has at least two neighbours.
    sort(c.begin(), c.end());
    // Odd-depth stations have only smaller-labelled neighbours, even-depth
    // stations only larger ones, so comparing s with the smallest neighbour
    // tells the two cases apart.
    return s > c[0] ? solveOdd(s, t, c) : solveEven(s, t, c);
}