#include <bits/stdc++.h>
#define vi vector<int>
#define vb vector<bool>
#define vl vector<ll>
#define vii vector<vector<int>>
#define vll vector<vector<ll>>
#define pi pair<int, int>
#define pl pair<ll, ll>
#define vpi vector<pair<int, int>>
#define vpl vector<pair<ll, ll>>
#define a first
#define b second
#define pb push_back
#define vm vector<mi>
#define vmm vector<vector<mi>>
using namespace std;
// Modulus for the modular integer type `mi` below.
const int MOD = 998244353;
using ll = long long;
using big = __int128_t;
using ld = long double;

// Sets target = min(target, cand) and returns the new value of target.
int uMin(int& target, int cand) { return target = std::min(target, cand); }
// Sets target = max(target, cand) and returns the new value of target.
int uMax(int& target, int cand) { return target = std::max(target, cand); }

// Integer modulo MOD; the stored value v is always normalised into [0, MOD).
struct mi {
    ll v;
    explicit operator int() const { return static_cast<int>(v); }
    mi() : v(0) {}
    // Deliberately implicit so plain integers mix with mi in arithmetic.
    mi(ll _v) : v(_v % MOD) { v += (v < 0) * MOD; }
    bool operator<(const mi& other) const { return v < other.v; }
    friend std::ostream& operator<<(std::ostream& os, const mi& m) { return os << m.v; }
};

mi &operator+=(mi &lhs, mi rhs) {
    if ((lhs.v += rhs.v) >= MOD) lhs.v -= MOD;
    return lhs;
}
mi &operator-=(mi &lhs, mi rhs) {
    if ((lhs.v -= rhs.v) < 0) lhs.v += MOD;
    return lhs;
}
mi operator+(mi lhs, mi rhs) { return lhs += rhs; }
mi operator-(mi lhs, mi rhs) { return lhs -= rhs; }
// Unary negation (additive inverse modulo MOD).
mi operator-(mi x) { return mi() - x; }
// Both factors are < MOD < 2^30, so the product fits in ll before reduction.
mi operator*(mi lhs, mi rhs) { return mi(lhs.v * rhs.v); }
mi &operator*=(mi &lhs, mi rhs) { return lhs = lhs * rhs; }
// base^p by iterative square-and-multiply; p must be non-negative.
mi pow(mi base, ll p) {
    assert(p >= 0);
    mi result = 1;
    for (; p > 0; p >>= 1, base *= base)
        if (p & 1) result *= base;
    return result;
}
// Multiplicative inverse via Fermat's little theorem (MOD is prime); x must be non-zero.
mi inv(mi x) {
    assert(x.v != 0);
    return pow(x, MOD - 2);
}
mi operator/(mi lhs, mi rhs) { return lhs * inv(rhs); }
mi &operator/=(mi &lhs, mi rhs) { return lhs = lhs / rhs; }
bool operator==(mi lhs, mi rhs) { return lhs.v == rhs.v; }
bool operator!=(mi lhs, mi rhs) { return lhs.v != rhs.v; }
// Dense square n x n matrix over T. The literals 0 and 1 are converted to T
// (zero matrix, identity), so T must be constructible from int.
template <class T>
class Matrix {
private:
    // Declared in the same order as the constructor's initializer list.
    std::size_t size;
    std::vector<std::vector<T>> data;
public:
    // n x n matrix with every entry equal to v.
    // Copy/move construction and assignment are the implicit member-wise ones.
    Matrix(std::size_t n, T v = T()) : size(n), data(n, std::vector<T>(n, v)) {}

    T& operator()(std::size_t i, std::size_t j) { return data[i][j]; }
    const T& operator()(std::size_t i, std::size_t j) const { return data[i][j]; }

    // Entry-wise sum; other must have the same size.
    Matrix<T> operator+(const Matrix<T>& other) const {
        Matrix<T> result(size);
        for (std::size_t i = 0; i < size; ++i)
            for (std::size_t j = 0; j < size; ++j)
                result(i, j) = data[i][j] + other(i, j);
        return result;
    }

    // Matrix product (*this) * other; other must have the same size.
    Matrix<T> operator*(const Matrix<T>& other) const {
        Matrix<T> result(size, 0);
        for (std::size_t i = 0; i < size; ++i)
            for (std::size_t j = 0; j < size; ++j) {
                const T& lhs = data[i][j];
                // Inner loop walks along row j of other and row i of result.
                for (std::size_t k = 0; k < size; ++k)
                    result(i, k) += lhs * other(j, k);
            }
        return result;
    }

    // Row-vector product: returns r with r[j] = sum_i vec[i] * (*this)(i, j),
    // i.e. vec^T * M (NOT M * vec). vec must hold at least `size` entries.
    std::vector<T> operator*(const std::vector<T>& vec) const {
        std::vector<T> result(size, 0);
        for (std::size_t i = 0; i < size; ++i)
            for (std::size_t j = 0; j < size; ++j)
                result[j] += data[i][j] * vec[i];
        return result;
    }

    // (*this)^e by binary exponentiation. e must be non-negative; a negative
    // exponent previously looped forever (arithmetic shift never reaches 0).
    Matrix<T> pow(long long e) const {
        assert(e >= 0);
        Matrix<T> result = identity(size);
        Matrix<T> base = *this;
        while (e > 0) {
            if (e & 1) result = result * base;
            base = base * base;
            e >>= 1;
        }
        return result;
    }

    // Equal iff both matrices have the same size and identical entries.
    bool operator==(const Matrix<T>& other) const {
        // Guard first: comparing entries of differently sized matrices would index out of range.
        if (size != other.size) return false;
        for (std::size_t i = 0; i < size; ++i)
            for (std::size_t j = 0; j < size; ++j)
                if (data[i][j] != other(i, j))
                    return false;
        return true;
    }

    // n x n identity matrix.
    static Matrix<T> identity(std::size_t n) {
        Matrix<T> id(n, 0);
        for (std::size_t i = 0; i < n; ++i)
            id(i, i) = 1;
        return id;
    }

    // Prints the matrix as a bracketed list of rows, one row per line.
    friend std::ostream& operator<<(std::ostream& os, const Matrix<T>& matrix) {
        os << "[\n";
        for (const auto& row : matrix.data) {
            os << "\t[";
            for (std::size_t i = 0; i < row.size(); i++)
                os << row[i] << (i == row.size() - 1 ? "]\n" : ", ");
        }
        os << "]\n";
        return os;
    }
};
// Fenwick tree (binary indexed tree) over 0-based positions [0, _N), storing
// point additions and answering prefix/range sums.
template <class T> class BIT {
    int N;                // number of internal slots (_N + 1)
    std::vector<T> data;  // data[i - 1] holds the Fenwick node for 1-based index i
public:
    BIT(int _N) {
        N = _N + 1;
        data.resize(N);
    }
    // Adds x to position p.
    void add(int p, T x) {
        for (p++; p <= N; p += p & -p) data[p - 1] += x;
    }
    // Sum of positions [l, r].
    T sum(int l, int r) { return sum(r) - (l == 0 ? T(0) : sum(l - 1)); }
    // Sum of positions [0, r].
    T sum(int r) {
        T s = 0;
        for (r++; r; r -= r & -r) s += data[r - 1];
        return s;
    }
    // Smallest position whose prefix sum is >= target, assuming non-negative
    // values; returns -1 if target <= 0, and a value >= _N if no prefix reaches it.
    int lower_bound(T target) {
        if (target <= 0) return -1;
        // Start from the highest power of two <= N (previously a fixed 1 << 25,
        // which silently gave wrong answers once N reached 2^26).
        int pw = 1;
        while (pw <= N / 2) pw <<= 1;
        int pos = 0;
        for (; pw; pw >>= 1) {
            const int npos = pos + pw;
            if (npos <= N && data[npos - 1] < target) {
                pos = npos;
                target -= data[pos - 1];
            }
        }
        return pos;
    }
};
// Range-add / range-sum over positions [0, size), built from two Fenwick trees.
// For 0-based i, the prefix sum [0, i] equals bit1.sum(i) * i - bit2.sum(i).
template <class T> class RURQ {
    BIT<T> bit1, bit2;
    int sz;
public:
    // The constructor and operations were previously private (class default
    // access), which made the type impossible to construct.
    RURQ(int size) : bit1(size), bit2(size), sz(size) {}
    // Adds value to every position in [start, end].
    void add(int start, int end, T value) {
        bit1.add(start, value);
        bit2.add(start, value * (start - 1));
        // No cancelling update is needed past the last position.
        if (end != sz - 1) {
            // T(0) - value instead of -value so T only needs binary minus.
            const T neg = T(0) - value;
            bit1.add(end + 1, neg);
            bit2.add(end + 1, neg * end);
        }
    }
    // Sum of positions [0, index]; 0 for index == -1.
    T pref(int index) {
        if (index == -1) return 0;
        return bit1.sum(index) * index - bit2.sum(index);
    }
    // Sum of positions [start, end].
    T sum(int start, int end) {
        return pref(end) - pref(start - 1);
    }
};
// A line y = k*x + m stored in LineContainer. The mutable field p is written
// by LineContainer (floored x of the intersection with the next line).
struct Line {
    mutable long long k;
    mutable long long m;
    mutable long long p;
    // Ordering used inside the container: by slope.
    bool operator<(const Line& rhs) const {
        return rhs.k > k;
    }
    // Heterogeneous comparison against a query x, used by lower_bound.
    bool operator<(long long x) const {
        return x > p;
    }
};
// Maintains a set of lines y = k*x + m and answers "maximum value at x" queries.
// Lines are kept ordered by slope k (Line::operator<). Each line's p is the
// floored x-coordinate of its intersection with the next line (inf for the
// last one), so queryMax can locate the best line with the transparent less<>.
struct LineContainer : multiset<Line, less<>> {
// (for doubles, use inf = 1/.0, div(a,b) = a/b)
static const ll inf = LLONG_MAX;
// Floored a / b: C++ '/' truncates toward zero, so subtract 1 when the signs
// differ and the division is inexact.
ll div(ll a, ll b) { // floored division
return a / b - ((a ^ b) < 0 && a % b); }
// Sets x->p to where y overtakes x and returns whether x->p >= y->p, i.e.
// whether y is never strictly on top and can be erased. When y == end(), x is
// the last line: its p becomes inf and false is returned (comma operator).
bool isect(iterator x, iterator y) {
if (y == end()) return x->p = inf, 0;
// Equal slopes: the line with the larger intercept dominates everywhere.
if (x->k == y->k) x->p = x->m > y->m ? inf : -inf;
else x->p = div(y->m - x->m, x->k - y->k);
return x->p >= y->p;
}
// Inserts y = k*x + m, then erases lines made redundant: following lines
// first, then the new line itself if its predecessor dominates it, then
// preceding lines, recomputing p values as it goes.
void add(ll k, ll m) {
auto z = insert({k, m, 0}), y = z++, x = y;
while (isect(y, z)) z = erase(z);
if (x != begin() && isect(--x, y)) isect(x, y = erase(y));
while ((y = x) != begin() && (--x)->p >= y->p)
isect(x, erase(y));
}
// Maximum of k*x + m over all inserted lines; the container must be non-empty.
// lower_bound(x) finds the first line whose p >= x.
// NOTE(review): l.k * x + l.m is evaluated in ll and can overflow for large
// slopes/queries — callers must keep it in range.
ll queryMax(ll x) {
assert(!empty());
auto l = *lower_bound(x);
return l.k * x + l.m;
}
};
// Extended Euclidean algorithm.
// Returns {s, t, g} with s*a + t*b == g, where g = gcd(a, b) as produced by
// repeated truncating division (non-negative when a, b >= 0).
std::vector<long long> euclid(long long a, long long b) {
    // Two rows of the Bezout table: (coefficient of a, coefficient of b, remainder).
    long long s0 = 1, t0 = 0, r0 = a;
    long long s1 = 0, t1 = 1, r1 = b;
    while (r1 != 0) {
        const long long q = r0 / r1;
        const long long s2 = s0 - q * s1;
        const long long t2 = t0 - q * t1;
        const long long r2 = r0 - q * r1;
        s0 = s1;
        t0 = t1;
        r0 = r1;
        s1 = s2;
        t1 = t2;
        r1 = r2;
    }
    return {s0, t0, r0};
}
// Generalised Chinese remainder theorem: solves x = mods[i].first (mod mods[i].second)
// for all i; the moduli need not be coprime.
// Returns {x, M} with M = lcm of all moduli and 0 <= x < M, or {-1, -1} if the
// congruences are inconsistent. An empty system yields {0, 1}.
// Moduli are assumed positive, and M must fit in a long long.
pl modSolver(const vpl& mods){
    if (mods.empty()) return {0, 1};
    long long mod = mods[0].second;
    long long rem = mods[0].first % mod;
    if (rem < 0) rem += mod;
    for (std::size_t i = 1; i < mods.size(); i++) {
        const long long target = mods[i].first;
        const long long m_i = mods[i].second;
        // g[0] * mod + g[1] * m_i == g[2] == gcd(mod, m_i)
        const vl g = euclid(mod, m_i);
        const long long diff = target - rem;
        if (diff % g[2]) return {-1, -1};
        // The shift t (x = rem + mod * t) is only determined modulo m_i / gcd.
        const long long step = m_i / g[2];
        // t = (diff / gcd) * g[0] (mod step). Both factors are reduced first and
        // multiplied in 128 bits; the previous ll product overflowed for large moduli.
        const __int128 t = static_cast<__int128>(diff / g[2] % step) * (g[0] % step) % step;
        const long long merged = mod / g[2] * m_i;
        __int128 x = (static_cast<__int128>(rem) + static_cast<__int128>(mod) * t) % merged;
        if (x < 0) x += merged;
        rem = static_cast<long long>(x);
        mod = merged;
    }
    return {rem, mod};
}
// Floor of the square root of n, computed bit by bit without floating point.
// Returns 0 for n <= 0. The result type is int, so for n >= 2^62 (whose exact
// root exceeds INT_MAX) the result saturates at INT_MAX = 2^31 - 1.
int sq(long long n){
    long long root = 0;
    // Bit 30 is the highest bit an int result can hold. Using 1LL avoids the
    // former 1 << 31 signed overflow, and (2^31 - 1)^2 < 2^62 cannot overflow.
    for (int bit = 30; bit >= 0; bit--) {
        const long long cand = root | (1LL << bit);
        if (cand * cand <= n) root = cand;
    }
    return static_cast<int>(root);
}
// Disjoint-set union stored in d: d[v] < 0 marks a root whose set has size -d[v];
// otherwise d[v] is v's parent.

// Returns the root of x's set. Every vertex on the walked path is re-pointed
// directly at the root (path compression).
int dGet(int x, std::vector<int>& d) {
    int root = x;
    while (d[root] >= 0) root = d[root];
    while (d[x] >= 0) {
        const int nxt = d[x];
        d[x] = root;
        x = nxt;
    }
    return root;
}

// Merges the sets of x and y, attaching the smaller set under the larger one.
// Returns false if they were already in the same set.
bool unite(int x, int y, std::vector<int>& d) {
    int rx = dGet(x, d);
    int ry = dGet(y, d);
    if (rx == ry) return false;
    // More negative d means a larger set; make rx the larger root.
    if (d[rx] > d[ry]) std::swap(rx, ry);
    d[rx] += d[ry];
    d[ry] = rx;
    return true;
}
// floor(log2(n)) for n >= 1. Returns 0 for n <= 1, including negative n,
// which previously looped forever because -1 >> 1 stays -1.
int lg(int n){
    if (n <= 1) return 0;
    int bits = 0;
    while (n >>= 1) ++bits;
    return bits;
}
// Manacher's algorithm for odd-length palindromes.
// Returns p with p[i] = largest k such that s[i-k+1 .. i+k-1] is a palindrome
// (so the longest odd palindrome centred at i has length 2*p[i] - 1).
// Uses explicit bounds checks instead of sentinel characters: the former
// "$"/"^" sentinels read out of bounds when s itself contained '$'.
std::vector<int> manacher_odd(std::string s) {
    const int n = static_cast<int>(s.size());
    std::vector<int> p(n);
    // [lo, hi) is the palindrome reaching furthest right among those found so far.
    int lo = 0, hi = 0;
    for (int i = 0; i < n; i++) {
        // Start from the mirrored centre's radius, clipped to stay inside [lo, hi).
        int k = (i < hi) ? std::min(hi - i, p[lo + hi - 1 - i]) : 1;
        while (i - k >= 0 && i + k < n && s[i - k] == s[i + k]) k++;
        p[i] = k;
        if (i + k > hi) {
            lo = i - k + 1;
            hi = i + k;
        }
    }
    return p;
}

// Palindromes of both parities. s is interleaved with '#' ("#s0#s1#...#") and
// the first and last entries are dropped, giving 2*|s| - 1 values: entry 2i
// is centred on s[i], entry 2i+1 between s[i] and s[i+1]. Each value minus 1 is
// the length of the longest palindrome in s at that centre. Results are only
// meaningful if s contains no '#'. Returns {} for empty s (previously the
// trimming step built a vector from an inverted iterator range).
std::vector<int> manacher(std::string s) {
    if (s.empty()) return {};
    std::string t;
    t.reserve(2 * s.size() + 1);
    for (char c : s) {
        t += '#';
        t += c;
    }
    t += '#';
    const std::vector<int> res = manacher_odd(t);
    return std::vector<int>(res.begin() + 1, res.end() - 1);
}
// Dynamic (lazily allocated) segment tree over positions [0, n) supporting
// point add / point assign and range sums. Nodes are created only along update
// paths; positions that were never updated read as 0.
template <class T> struct ST {
    std::vector<T> tree;            // aggregated value of each allocated node
    std::vector<int> left, right;   // child node indices, -1 when not allocated
    int n;                          // number of positions covered

    ST(int N) : n(N) {
        make_node();  // node 0 is the root, covering [0, n - 1]
    }

    // Appends a node with value 0 and no children; returns its index.
    int make_node() {
        tree.push_back(0);
        left.push_back(-1);
        right.push_back(-1);
        return static_cast<int>(tree.size()) - 1;
    }

    // Point update of position s inside node ind covering [l, r]: adds x when
    // accumulate is true, otherwise overwrites with x. Parent sums are
    // recomputed on the way back up.
    void upd(int l, int r, int ind, int s, T x, bool accumulate) {
        if (l == r) {
            if (accumulate) tree[ind] += x;
            else tree[ind] = x;
            return;
        }
        const int m = l + ((r - l) >> 1);
        if (s <= m) {
            if (left[ind] == -1) {
                const int child = make_node();  // allocate before indexing: push_back may reallocate
                left[ind] = child;
            }
            upd(l, m, left[ind], s, x, accumulate);
        } else {
            if (right[ind] == -1) {
                const int child = make_node();
                right[ind] = child;
            }
            upd(m + 1, r, right[ind], s, x, accumulate);
        }
        // A missing child contributes nothing, so copy the existing one.
        if (left[ind] == -1) tree[ind] = tree[right[ind]];
        else if (right[ind] == -1) tree[ind] = tree[left[ind]];
        else tree[ind] = tree[left[ind]] + tree[right[ind]];
    }

    // Adds x to position s.
    void add(int s, T x) {
        upd(0, n - 1, 0, s, x, true);
    }
    // Sets position s to x.
    void set(int s, T x) {
        upd(0, n - 1, 0, s, x, false);
    }

    // Sum of [l, r] inside node ind covering [cL, cR]; requires cL <= l <= r <= cR.
    // An unallocated node (ind == -1) contributes 0. Without this guard, queries
    // touching never-updated ranges (e.g. on an empty tree) read tree[-1].
    T sum(int cL, int cR, int ind, int l, int r) {
        if (ind == -1) return T(0);
        if (cL == l && cR == r) return tree[ind];
        const int m = cL + ((cR - cL) >> 1);
        if (r <= m) return sum(cL, m, left[ind], l, r);
        if (l > m) return sum(m + 1, cR, right[ind], l, r);
        return sum(cL, m, left[ind], l, m) + sum(m + 1, cR, right[ind], m + 1, r);
    }
    // Sum of positions [l, r].
    T sum(int l, int r) {
        return sum(0, n - 1, 0, l, r);
    }
};
// x * x.
long long square(long long x) {
    const long long result = x * x;
    return result;
}
// Squared Euclidean length of the vector p.
long long norm(const std::pair<long long, long long>& p) {
    const long long sx = square(p.first);
    const long long sy = square(p.second);
    return sx + sy;
}
// 2D cross product u x w (positive when w is counter-clockwise from u).
long long cross(const std::pair<long long, long long>& u,
                const std::pair<long long, long long>& w) {
    return u.first * w.second - u.second * w.first;
}
// Cross product of (u - o) and (w - o): orientation of the turn o -> u -> w.
long long cross(const std::pair<long long, long long>& o,
                const std::pair<long long, long long>& u,
                const std::pair<long long, long long>& w) {
    const std::pair<long long, long long> du{u.first - o.first, u.second - o.second};
    const std::pair<long long, long long> dw{w.first - o.first, w.second - o.second};
    return cross(du, dw);
}
// Convex hull of the points in v by Graham scan.
// Returns indices into v of the hull vertices in counter-clockwise order,
// starting at the lexicographically smallest point. The "<= 0" turn test
// drops points collinear with a hull edge, and duplicates collapse.
// Returns {} for empty input (previously {0}, an invalid index).
vi hullInd(const vpl& v) {
    if (v.empty()) return {};
    const int ind = static_cast<int>(std::min_element(v.begin(), v.end()) - v.begin());
    const pl pivot = v[ind];
    vi cand, hull{ind};
    for (int i = 0; i < static_cast<int>(v.size()); i++)
        if (v[i] != pivot)
            cand.push_back(i);
    // Sort by polar angle around the pivot; ties (same ray) by distance.
    std::sort(cand.begin(), cand.end(), [&](int i, int j) {
        const pl x = {v[i].first - pivot.first, v[i].second - pivot.second};
        const pl y = {v[j].first - pivot.first, v[j].second - pivot.second};
        const ll t = cross(x, y);
        return t != 0 ? t > 0 : norm(x) < norm(y);
    });
    for (int c : cand) {
        // Pop while the last two hull points and c do not make a strict left turn.
        while (hull.size() > 1 && cross(v[hull[hull.size() - 2]], v[hull.back()], v[c]) <= 0)
            hull.pop_back();
        hull.push_back(c);
    }
    return hull;
}
// Unsigned 64-bit type used by the primality / factorisation helpers that follow.
typedef unsigned long long ull;
// Returns a * b mod M without 128-bit arithmetic. The quotient floor(a*b / M)
// is estimated in long double, the wrapping 64-bit difference a*b - M*q is
// read as signed, and it is corrected into [0, M) by adding or subtracting M once.
// NOTE(review): correctness depends on long double precision (x87 80-bit);
// the KACTL routine this appears to come from documents it for M < 7.2e18 — confirm.
ull modmul(ull a, ull b, ull M) {
ll ret = a * b - M * ull(1.L / M * a * b);
return ret + M * (ret < 0) - M * (ret >= (ll)M);
}
// b^e mod `mod` by square-and-multiply using modmul.
// Note: returns 1 for e == 0 even when mod == 1.
ull modpow(ull b, ull e, ull mod) {
ull ans = 1;
for (; e; b = modmul(b, b, mod), e /= 2)
if (e & 1) ans = modmul(ans, b, mod);
return ans;
}
// Miller-Rabin primality test with a fixed base set.
// NOTE(review): this base set appears to be the one published as deterministic
// for all 64-bit n (Sinclair's bases); the range also depends on modmul's
// precision — confirm before relying on it near 2^63.
bool isPrime(ull n) {
// n % 6 % 4 == 1 exactly when n % 6 is 1 or 5. Everything else (n < 2 or a
// multiple of 2 or 3) is prime only if it is 2 or 3.
if (n < 2 || n % 6 % 4 != 1) return (n | 1) == 3;
// n - 1 = d * 2^s with d odd; since n is odd, n >> s == (n - 1) >> s.
ull A[] = {2, 325, 9375, 28178, 450775, 9780504, 1795265022},
s = __builtin_ctzll(n-1), d = n >> s;
for (ull a : A) {
// p = a^d mod n, then square up to s times looking for n - 1.
// Bases that are multiples of n are skipped via the a % n test.
ull p = modpow(a%n, d, n), i = s;
while (p != 1 && p != n - 1 && a % n && i--)
p = modmul(p, p, n);
// Composite witness: the chain never reached n - 1, and p was not 1 from
// the start (i != s means at least one squaring, or the counter ran out).
if (p != n-1 && i != s) return 0;
}
return 1;
}
// Pollard's rho: returns a non-trivial divisor of n. factor() only calls it
// for composite n. For prime n the gcd test below would never exceed 1 —
// do not call it with a prime.
// Iterates f(x) = x^2 + i (mod n, plus i) with tortoise x and hare y, multiplies the
// differences into prd, and takes gcd(prd, n) only every 40 steps to save gcds.
ull pollard(ull n) {
ull x = 0, y = 0, t = 30, prd = 2, i = 1, q;
auto f = [&](ull x) { return modmul(x, x, n) + i; };
while (t++ % 40 || __gcd(prd, n) == 1) {
// Cycle closed without success: restart with a new constant i.
if (x == y) x = ++i, y = f(x);
// Skip a zero product so prd never collapses to 0 (which would make the gcd n).
if ((q = modmul(prd, max(x,y) - min(x,y), n))) prd = q;
x = f(x), y = f(f(y));
}
return __gcd(prd, n);
}
// Prime factorisation of n with multiplicity, in no particular order
// (results of the two recursive halves are concatenated).
// NOTE(review): assumes n >= 1; n == 0 would reach pollard(0) and modmul(..., 0).
vector<ull> factor(ull n) {
if (n == 1) return {};
if (isPrime(n)) return {n};
ull x = pollard(n);
auto l = factor(x), r = factor(n / x);
l.insert(l.end(),r.begin(),r.end());
return l;
}
// SMAWK-style row-optimum search over an implicit matrix.
// f(row, j, k) returns true when column k should replace the current choice j
// for `row`. Returns, for each entry of rows, the chosen column label.
// NOTE(review): correct only if the implicit matrix is totally monotone with
// respect to f (not checked), and the interpolation step compares column labels
// with <=, so cols is presumably in increasing order.
template<class F>
std::vector<int> smawck(F f, const std::vector<int> &rows, const std::vector<int> &cols) {
std::vector<int> ans(rows.size(), -1);
// Base case: at most 2 rows and 2 columns — scan every column for every row.
if((int) std::max(rows.size(), cols.size()) <= 2) {
for(int i = 0; i < (int) rows.size(); i++) {
for(auto j : cols) {
if(ans[i] == -1 || f(rows[i], ans[i], j)) {
ans[i] = j;
}
}
}
} else if(rows.size() < cols.size()) {
// REDUCE: more columns than rows. Keep a stack of at most rows.size()
// candidate columns; column st.back() is popped when j beats it on row
// rows[st.size() - 1].
std::vector<int> st;
for(int j : cols) {
while(1) {
if(st.empty()) {
st.push_back(j);
break;
} else if(f(rows[(int) st.size() - 1], st.back(), j)) {
st.pop_back();
} else if(st.size() < rows.size()) {
st.push_back(j);
break;
} else {
break;
}
}
}
ans = smawck(f, rows, st);
} else {
// INTERPOLATE: solve the odd-indexed rows recursively ...
std::vector<int> newRows;
for(int i = 1; i < (int) rows.size(); i += 2) {
newRows.push_back(rows[i]);
}
auto otherAns = smawck(f, newRows, cols);
for(int i = 0; i < (int) newRows.size(); i++) {
ans[2*i+1] = otherAns[i];
}
// ... then fill each even row by scanning columns from the previous odd
// row's answer up to and including the next odd row's answer (or to the
// last column for a trailing even row). l-- makes the next scan start at
// the column where this one ended.
for(int i = 0, l = 0, r = 0; i < (int) rows.size(); i += 2) {
if(i+1 == (int) rows.size()) r = (int) cols.size();
while(r < (int) cols.size() && cols[r] <= ans[i+1]) r++;
ans[i] = cols[l++];
for(; l < r; l++) {
if(f(rows[i], ans[i], cols[l])) {
ans[i] = cols[l];
}
}
l--;
}
}
return ans;
}
// Convenience overload: searches an n x m implicit matrix whose row and column
// labels are simply 0..n-1 and 0..m-1.
template<class F>
std::vector<int> smawck(F f, int n, int m) {
    std::vector<int> rowLabels(n);
    std::vector<int> colLabels(m);
    std::iota(rowLabels.begin(), rowLabels.end(), 0);
    std::iota(colLabels.begin(), colLabels.end(), 0);
    return smawck(f, rowLabels, colLabels);
}
// Max-plus convolution: ans[k] = max over i + j == k of anyShape[i] + convexShape[j],
// for k in [0, n + m - 1). It finds the optimal i for every k with smawck.
// NOTE(review): convexShape is assumed to have the shape that makes the optimal
// i monotone in k (presumably concave for max-plus) — this is not checked.
template<class T>
std::vector<T> MaxConvolutionWithConvexShape(std::vector<T> anyShape, const std::vector<T> &convexShape) {
    // No second operand: the input is returned unchanged (existing contract).
    if (convexShape.empty()) return anyShape;
    // A single-element operand shifts every entry by convexShape[0]. This case
    // previously returned anyShape unshifted, which is only right when
    // convexShape[0] == 0.
    if ((int) convexShape.size() == 1) {
        for (T &value : anyShape) value = value + convexShape[0];
        return anyShape;
    }
    if (anyShape.empty()) anyShape.push_back(0);
    const int n = (int) anyShape.size(), m = (int) convexShape.size();
    // Value of output k when anyShape[j] is paired with convexShape[k - j].
    auto candidate = [&](int i, int j) {
        return anyShape[j] + convexShape[i - j];
    };
    // true when column k should replace column j for output row i.
    auto comparator = [&](int i, int j, int k) {
        if (i < k) return false;      // k > i would need a negative convexShape index
        if (i - j >= m) return true;  // j is outside convexShape's range for row i
        return candidate(i, j) <= candidate(i, k);
    };
    const std::vector<int> best = smawck(comparator, n + m - 1, n);
    std::vector<T> ans(n + m - 1);
    for (int i = 0; i < n + m - 1; i++) {
        ans[i] = candidate(i, best[i]);
    }
    return ans;
}
// Maximum subarray sum of arr[l..r] (inclusive). The empty subarray is allowed,
// so the result is never negative, and it is 0 when l > r.
int kadane(std::vector<int>& arr, int l, int r) {
    int best = 0;
    for (int i = l, run = 0; i <= r; ++i) {
        // Best sum of a subarray ending at i: drop a negative prefix.
        run = (run > 0 ? run : 0) + arr[i];
        if (run > best) best = run;
    }
    return best;
}
// Reads a tree (n vertices, n - 1 edges) and two vertices a, b (1-based in the
// input, converted to 0-based). The a-b path is split at one edge: the part
// containing b is evaluated by dfs2 rooted at b, and the part containing a by
// dfs2 rooted at a. Prints the minimum over split edges of the larger of the
// two values, minus 1.
// NOTE(review): assumes a != b. With a == b, path has one vertex and
// path[lo+1] below is out of range.
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
int n,a,b; cin >> n >> a >> b;
a--,b--;
// Undirected adjacency lists; edge endpoints are read 1-based.
vii g(n);
for(int i = 1; i<n; i++){
int x,y; cin >> x >> y;
g[--x].pb(--y);
g[y].pb(x);
}
// depth[v] = distance from a (tree rooted at a).
vi depth(n);
function<void(int,int)> dfs1;
dfs1 = [&](int s, int p){
if(p != -1) depth[s] = depth[p]+1;
for(int i: g[s]) if(i != p)
dfs1(i,s);
};
dfs1(a,-1);
// path[0] = b and path.back() = a: repeatedly step to the neighbour one level
// closer to a.
vi path{b};
while(path[path.size()-1] != a)
for(int i: g[path[path.size()-1]])
if(depth[i]+1 == depth[path[path.size()-1]]){
path.pb(i);
break;
}
// dfs2(s, p, stop): value of the subtree of s (parent p) with the vertex
// `stop` excluded. Children's values are sorted in decreasing order and the
// result is 1 + max_i(next[i] + i): the "each informed vertex informs one
// neighbour per round, slowest subtree first" recurrence (presumed intent —
// verify against the task statement).
// NOTE(review): dfs1/dfs2 recurse once per tree level through std::function;
// a long path-shaped tree may exhaust the stack for large n.
function<int(int,int,int)> dfs2;
dfs2 = [&](int s, int p, int stop){
vi next;
for(int i: g[s]) if(i != p && i != stop)
next.pb(dfs2(i,s,stop));
sort(next.rbegin(),next.rend());
int ans = 0;
for(int i = 0; i<next.size(); i++) uMax(ans,next[i]+i);
return ans+1;
};
// Split index m removes edge (path[m], path[m+1]): b keeps path[0..m] and a
// keeps the rest. a1 (b side) should not decrease and a2 (a side) should not
// increase as m grows, so binary-search for the crossing point and then
// evaluate both remaining candidates lo and hi.
int lo = 0, hi = path.size()-2;
while(lo<hi-1){
int m = (lo+hi)>>1;
int a1 = dfs2(b,-1,path[m+1]);
int a2 = dfs2(a,-1,path[m]);
if(a1<a2) lo = m;
else hi = m;
}
cout << min(max(dfs2(b,-1,path[lo+1]),dfs2(a,-1,path[lo])),max(dfs2(b,-1,path[hi+1]),dfs2(a,-1,path[hi])))-1 << "\n";
}
// Leftover judge-status text, kept as comments so the file still compiles:
// # | Verdict | Execution time | Memory | Grader output |
// ---|
// Fetching results... |
// # | Verdict | Execution time | Memory | Grader output |
// ---|
// Fetching results... |