제출 #1129437

| # | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리 |
|---|---|---|---|---|---|---|---|
| 1129437 | | CutSandstone | Torrent (COI16_torrent) | C++20 | 100 / 100 | 389 ms | 25312 KiB |
#include <bits/stdc++.h>
// Shorthand type/container macros. These are textual substitutions, so the
// identifiers vi, vb, vl, vii, vll, pi, pl, vpi, vpl, pb, vm, vmm cannot be
// used as ordinary names anywhere below (e.g. a variable called `pi`).
#define vi vector<int>
#define vb vector<bool>
#define vl vector<ll>
#define vii vector<vector<int>>
#define vll vector<vector<ll>>
#define pi pair<int, int>
#define pl pair<ll, ll>
#define vpi vector<pair<int, int>>
#define vpl vector<pair<ll, ll>>
// WARNING: `a` and `b` are object-like macros. Every later identifier token `a`
// or `b` (parameters, locals, lambda arguments) is rewritten to `first` /
// `second`. That is what makes `p.a` / `p.b` work on pairs, but it also means
// no scope may declare its own `first`/`second` next to an `a`/`b`.
#define a first 
#define b second
#define pb push_back
#define vm vector<mi>
#define vmm vector<vector<mi>>
using namespace std;
// Modulus used by the `mi` modular-integer type.
const int MOD = 998244353;
using ll = long long;
using big = __int128_t;  // GCC/Clang extension, used for overflow-free intermediates
using ld = long double;
// Lowers `a` to `b` when `b` is smaller; returns the (possibly updated) value of `a`.
int uMin(int& a, int b) {
    if (b < a) a = b;
    return a;
}
// Raises `a` to `b` when `b` is larger; returns the (possibly updated) value of `a`.
int uMax(int& a, int b) {
    if (a < b) a = b;
    return a;
}
struct mi {
  ll v;
  explicit operator int() const { return v; }
  mi() { v = 0; }
  mi(ll _v) : v(_v % MOD) { v += (v < 0) * MOD; }
  bool operator<(const mi& other) const {
    return v < other.v;
  }
  friend std::ostream& operator<<(std::ostream& os, const mi& m) {
      os << m.v;
      return os;
  }
};
// In-place modular addition; operands are expected to already be reduced into [0, MOD).
mi &operator+=(mi &a, mi b) {
  a.v += b.v;
  if (a.v >= MOD) a.v -= MOD;
  return a;
}
// In-place modular subtraction; a negative difference wraps back into [0, MOD).
mi &operator-=(mi &a, mi b) {
  a.v -= b.v;
  if (a.v < 0) a.v += MOD;
  return a;
}
// Value-returning arithmetic built on the compound operators.
mi operator+(mi a, mi b) { mi r = a; r += b; return r; }
mi operator-(mi a, mi b) { mi r = a; r -= b; return r; }
// Both residues are < MOD < 2^30, so the raw product fits in ll before reduction.
mi operator*(mi a, mi b) { return mi(a.v * b.v); }
mi &operator*=(mi &a, mi b) { a = a * b; return a; }
// a^p modulo MOD for p >= 0, by iterative square-and-multiply.
mi pow(mi a, ll p) {
  assert(p >= 0);
  mi result = 1, base = a;
  while (p > 0) {
    if (p & 1) result = result * base;
    base = base * base;
    p >>= 1;
  }
  return result;
}
// Multiplicative inverse modulo MOD (asserts a != 0).
mi inv(mi a) {
  assert(a.v != 0);
  // Fermat's little theorem: a^(MOD-2) is the inverse because MOD (998244353) is prime.
  const ll exponent = MOD - 2;
  return pow(a, exponent);
}
// Division via the modular inverse; equality compares canonical residues.
mi operator/(mi a, mi b) { return a * inv(b); }
bool operator==(mi a, mi b) { return a.v == b.v; }
bool operator!=(mi a, mi b) { return !(a == b); }
// Dense square matrix over a ring-like element type T (e.g. the `mi` modular
// type). T must be constructible from 0 and 1 and support +, +=, * and !=.
template <class T>
class Matrix {
private:
    size_t size;                       // the matrix is size x size
    std::vector<std::vector<T>> data;  // data[i][j] = row i, column j
public:
    // n x n matrix with every entry equal to v.
    // No user-declared copy constructor (Rule of Zero): the implicit copy and
    // move members are correct, and the moves let `result = result * base`
    // steal buffers instead of deep-copying them.
    Matrix(size_t n, T v = T()) : size(n), data(n, std::vector<T>(n, v)) {}
    T& operator()(size_t i, size_t j) { return data[i][j]; }
    const T& operator()(size_t i, size_t j) const { return data[i][j]; }
    // Entry-wise sum; `other` must have the same size.
    Matrix<T> operator+(const Matrix<T>& other) const {
        Matrix<T> result(size);
        for (size_t i = 0; i < size; ++i)
            for (size_t j = 0; j < size; ++j)
                result(i, j) = data[i][j] + other(i, j);
        return result;
    }
    // Matrix product (*this) * other; `other` must have the same size.
    // i-j-k loop order keeps the innermost accesses row-contiguous.
    Matrix<T> operator*(const Matrix<T>& other) const {
        Matrix<T> result(size, 0);
        for (size_t i = 0; i < size; ++i)
            for (size_t j = 0; j < size; ++j)
                for (size_t k = 0; k < size; ++k)
                    result(i, k) += data[i][j] * other(j, k);
        return result;
    }
    // Row-vector product vec^T * (*this): result[j] = sum_i vec[i] * M(i, j).
    // NOTE(review): this is the transpose of the conventional M * v; kept as is
    // so existing callers relying on the row-vector convention are unaffected.
    std::vector<T> operator*(const std::vector<T>& vec) const {
        std::vector<T> result(size, 0);
        for (size_t i = 0; i < size; ++i)
            for (size_t j = 0; j < size; ++j)
                result[j] += data[i][j] * vec[i];
        return result;
    }
    // (*this)^e by binary exponentiation. Requires e >= 0 (e == 0 gives the identity).
    Matrix<T> pow(long long e) const {
        Matrix<T> result = identity(size);
        Matrix<T> base = *this;
        while (e) {
            if (e & 1) result = result * base;
            base = base * base;
            e >>= 1;
        }
        return result;
    }
    // Equal iff both matrices have the same size and identical entries.
    // The size check prevents reading past `other` when it is smaller.
    bool operator==(const Matrix<T>& other) const {
        if (size != other.size) return false;
        for (size_t i = 0; i < size; ++i)
            for (size_t j = 0; j < size; ++j)
                if (data[i][j] != other(i, j))
                    return false;
        return true;
    }
    // n x n identity matrix.
    static Matrix<T> identity(size_t n) {
        Matrix<T> id(n, 0);
        for (size_t i = 0; i < n; ++i)
            id(i, i) = 1;
        return id;
    }
    friend std::ostream& operator<<(std::ostream& os, const Matrix<T>& matrix) {
        os << "[\n";
        for (const auto& row : matrix.data) {
            os << "\t[";
            for (size_t i = 0; i < row.size(); i++)
                os << row[i] << (i + 1 == row.size() ? "]\n" : ", ");
        }
        os << "]\n";
        return os;
    }
};
// Fenwick tree (binary indexed tree) over 0-based positions [0, n].
//   add(p, x)      : a[p] += x
//   sum(r)         : a[0] + ... + a[r]   (0 when r == -1)
//   sum(l, r)      : a[l] + ... + a[r]
//   lower_bound(s) : smallest p with sum(p) >= s, or -1 if s <= 0.
//                    Requires all a[i] >= 0 so that prefix sums are monotone.
template <class T> class BIT {
    int N; std::vector<T> data;  // data is 1-indexed internally: data[k-1] holds node k
    public:
    BIT(int n){
        N = n+1;
        data.resize(N);
    }
    void add(int p, T x) { for (p++; p<=N; p+=p&-p) data[p-1]+=x; }
    T sum(int l, int r) { return sum(r)-(l==0?0:sum(l-1)); }
    T sum(int r) { T s = 0; for (r++; r; r-=r&-r) s+=data[r-1]; return s; }
    int lower_bound(T sum) {
        if (sum <= 0) return -1;
        // Start the binary descent at the largest power of two <= N, so this
        // works for any size (a fixed starting step would cap the tree size).
        int pw = 1;
        while (pw <= N / 2) pw <<= 1;
        int pos = 0;
        for (; pw; pw >>= 1) {
            int npos = pos+pw;
            if (npos <= N && data[npos-1] < sum)
                pos = npos, sum -= data[pos-1];
        }
        return pos;
    }
};
// Range-add / range-sum over positions [0, size), built on two Fenwick trees.
//   add(l, r, v) : a[i] += v for every l <= i <= r
//   sum(l, r)    : a[l] + ... + a[r]
// Uses prefix(i) = bit1.sum(i) * i - bit2.sum(i), where bit1 stores the
// difference array of the added values and bit2 stores the value*(start-1)
// corrections.
template <class T> class RURQ {
    BIT<T> bit1, bit2;
    int sz;
public:
    // Members used to be implicitly private (no access specifier), which made the
    // class impossible to construct or use.
    RURQ(int size) : bit1(size), bit2(size), sz(size) {}
    void add(int start, int end, T value) {
        bit1.add(start, value);
        bit2.add(start, value*(start - 1));
        // Nothing follows the last position, so the closing entries are unnecessary there.
        if(end != sz-1){
            bit1.add(end+1, -value);
            bit2.add(end+1, -value*end);
        }
    }
    // a[0] + ... + a[index]; 0 for index == -1.
    T pref(int index) {
        if(index==-1) return 0;
        return (bit1.sum(index))*index-bit2.sum(index);
    }
    T sum(int start, int end) {
        return pref(end)-pref(start-1);
    }
};
// One line y = k*x + m of a LineContainer. `p` is bookkeeping written by
// LineContainer::isect: the last x at which this line beats its successor.
struct Line {
	mutable ll k, m, p;
	// Set ordering: by slope.
	bool operator<(const Line& o) const { return o.k > k; }
	// Heterogeneous comparison used by lower_bound(x) in queries.
	bool operator<(ll x) const { return x > p; }
};

// Dynamic upper envelope of lines y = k*x + m (KACTL-style "line container").
// add(k, m) inserts a line; queryMax(x) returns max over inserted lines of k*x + m.
// Lines are kept ordered by slope. Each line's p holds the last integer x at which
// it is still at least as good as the next line, so lower_bound(x) (which uses
// Line::operator<(ll)) lands on the line that is optimal at x.
struct LineContainer : multiset<Line, less<>> {
	// (for doubles, use inf = 1/.0, div(a,b) = a/b)
	static const ll inf = LLONG_MAX;
	ll div(ll a, ll b) { // floored division
		return a / b - ((a ^ b) < 0 && a % b); }
	// Recomputes x->p against its successor y (inf when y is end(); +/-inf for
	// equal slopes depending on which intercept is larger). Returns true when
	// x->p >= y->p, meaning y is never strictly the best line and can be erased.
	bool isect(iterator x, iterator y) {
		if (y == end()) return x->p = inf, 0;
		if (x->k == y->k) x->p = x->m > y->m ? inf : -inf;
		else x->p = div(y->m - x->m, x->k - y->k);
		return x->p >= y->p;
	}
	void add(ll k, ll m) {
		auto z = insert({k, m, 0}), y = z++, x = y;
		// Erase successors made redundant by the new line y.
		while (isect(y, z)) z = erase(z);
		// If the predecessor already dominates y, drop y and fix the predecessor's p.
		if (x != begin() && isect(--x, y)) isect(x, y = erase(y));
		// Erase predecessors that no longer appear on the envelope.
		while ((y = x) != begin() && (--x)->p >= y->p)
			isect(x, erase(y));
	}
	ll queryMax(ll x) {
		assert(!empty());
		auto l = *lower_bound(x);
		return l.k * x + l.m;
	}
};
// Extended Euclid. Returns {s, t, g} with s*a + t*b == g, where g = gcd(a, b).
vl euclid(ll a, ll b) {
    // Invariants: s0*a + t0*b == r0 and s1*a + t1*b == r1.
    ll s0 = 1, t0 = 0, r0 = a;
    ll s1 = 0, t1 = 1, r1 = b;
    while (r1 != 0) {
        const ll q = r0 / r1;
        const ll ns = s0 - q * s1;
        const ll nt = t0 - q * t1;
        const ll nr = r0 - q * r1;
        s0 = s1; t0 = t1; r0 = r1;
        s1 = ns; t1 = nt; r1 = nr;
    }
    return {s0, t0, r0};
}
// Chinese Remainder Theorem for possibly non-coprime moduli.
// Combines the congruences x = mods[i].a (mod mods[i].b) into x = ans.a (mod ans.b).
// Returns {-1, -1} if the system has no solution. Moduli must be positive, and the
// combined modulus (their lcm) must fit in ll. An empty system returns {0, 1}.
pl modSolver(vpl& mods){
    if(mods.empty()) return {0, 1};  // no constraints: x = 0 (mod 1) covers every integer
    pl ans = {mods[0].a,mods[0].b};
    for(int i = 1; i<(int)mods.size(); i++){
        vl g = euclid(ans.b, mods[i].b);  // g[0]*ans.b + g[1]*mods[i].b == g[2] == gcd
        ll diff = mods[i].a-ans.a;
        if(diff%g[2]) return {-1,-1};
        ll step = mods[i].b/g[2];           // k only matters modulo this value
        ll newMod = ans.b/g[2]*mods[i].b;   // lcm of the two moduli
        // x = ans.a + ans.b * k with k = (diff/g) * g[0] (mod step). The raw product
        // (diff/g)*ans.b*g[0] overflows ll even when newMod fits, so reduce k first
        // and form the candidate in 128-bit arithmetic.
        big k = static_cast<big>(diff/g[2]) % step * (g[0] % step) % step;
        big x = (static_cast<big>(ans.a) + static_cast<big>(ans.b) * k) % newMod;
        if(x < 0) x += newMod;
        ans = {static_cast<ll>(x), newMod};
    }
    return ans;
}
// Integer square root: floor(sqrt(n)) for n >= 0, and 0 for n < 0.
// The result fits the int return type only for n < 2^62.
int sq(long long n){
    long long r = 0;
    for(int i = 31; i>=0; i--){
        // 1LL: a plain `1<<31` is an int shift that overflows.
        long long x = (1LL<<i)|r;
        // For x > 0, x <= n / x is equivalent to x*x <= n but cannot overflow.
        if(x <= n / x) r = x;
    }
    return r;
}
// DSU find with full path compression. Roots hold a negative value in d;
// every other node stores its parent index.
int dGet(int a, vi& d){
    int root = a;
    while (d[root] >= 0) root = d[root];
    // Second pass: point every node on the walked path directly at the root.
    while (d[a] >= 0) {
        const int nxt = d[a];
        d[a] = root;
        a = nxt;
    }
    return root;
}
bool unite(int a, int b, vi& d){
    a = dGet(a,d), b = dGet(b,d);
    if(a == b) return 0;
    if(d[a]>d[b]) swap(a,b);
    d[a]+=d[b];
    d[b] = a;
    return 1;
}
// floor(log2(n)) for n >= 1. Returns 0 for n <= 0: a negative n never reaches 0
// under arithmetic right shift (it sticks at -1), so shifting it would loop forever.
int lg(int n){
    if(n <= 0) return 0;
    int i = 0;
    while(n >>= 1) i++;
    return i;
}
// Odd-length Manacher. For each 0-based position i of s returns p[i] >= 1 such
// that s[i-p[i]+1 .. i+p[i]-1] is the longest odd palindrome centred at i.
std::vector<int> manacher_odd(std::string s) {
    const int n = (int)s.size();
    // Pad so the original characters sit at 1-based positions [1, n]. Expansion is
    // bounded to that range explicitly. Relying on the pads as sentinels breaks for
    // inputs that themselves contain '$' or '^'.
    s = "$" + s + "^";
    std::vector<int> p(n + 2);
    int l = 1, r = 1;  // rightmost palindrome found so far, as exclusive bounds (l, r)
    for(int i = 1; i <= n; i++) {
        // Reuse the mirrored centre's radius, clipped to the known window.
        p[i] = std::max(0, std::min(r - i, p[l + (r - i)]));
        while(i - p[i] >= 1 && i + p[i] <= n && s[i - p[i]] == s[i + p[i]]) {
            p[i]++;
        }
        if(i + p[i] > r) {
            l = i - p[i], r = i + p[i];
        }
    }
    return std::vector<int>(std::begin(p) + 1, std::end(p) - 1);
}

vector<int> manacher(string s) {
    string t;
    for(auto c: s) {
        t += string("#") + c;
    }
    auto res = manacher_odd(t + "#");
    return vector<int>(begin(res) + 1, end(res) - 1);
}
// Sparse (dynamically allocated) sum segment tree over positions [0, n).
// Nodes are created lazily by updates. A child index of -1 means that subtree was
// never touched and therefore sums to zero. T must be constructible from 0 and
// support + and +=.
template <class T> struct ST {
    std::vector<T> tree;           // tree[i]: sum over node i's range
    std::vector<int> left, right;  // child node indices, -1 if absent
    int n;
    ST(int N) : n(N){
        makeNode();  // node 0 is the root covering [0, n-1]
    }
    // Appends a fresh zero node with no children and returns its index.
    int makeNode(){
        tree.push_back(0);
        left.push_back(-1);
        right.push_back(-1);
        return (int)tree.size() - 1;
    }
    // Point update on node `ind` covering [l, r]: adds x at position s when `add`
    // is true, otherwise assigns x there.
    void upd(int l, int r, int ind, int s, T x, bool add){
        if(l == r){
            if(add) tree[ind]+=x;
            else tree[ind] = x;
            return;
        }
        int m = l+((r-l)>>1);
        if(s<=m){
            if(left[ind] == -1){
                int c = makeNode();  // may reallocate `left`, so index it afterwards
                left[ind] = c;
            }
            upd(l,m,left[ind],s,x,add);
        }else{
            if(right[ind] == -1){
                int c = makeNode();
                right[ind] = c;
            }
            upd(m+1,r,right[ind],s,x,add);
        }
        if(left[ind] == -1) tree[ind] = tree[right[ind]];
        else if(right[ind] == -1) tree[ind] = tree[left[ind]];
        else tree[ind] = tree[left[ind]]+tree[right[ind]];
    }
    void add(int s, T x){
        upd(0,n-1,0,s,x,1);
    }
    void set(int s, T x){
        upd(0,n-1,0,s,x,0);
    }
    // Sum of [l, r] inside node `ind` covering [cL, cR].
    T sum(int cL, int cR, int ind, int l, int r){
        // Untouched subtree. Without this check, a query that descends into a
        // missing child indexed tree[-1] / left[-1].
        if(ind == -1) return T(0);
        if(cL == l && cR == r) return tree[ind];
        int m = cL+((cR-cL)>>1);
        if(r<=m) return sum(cL,m,left[ind],l,r);
        if(l>m) return sum(m+1,cR,right[ind],l,r);
        if(left[ind] == -1) return sum(m+1,cR,right[ind],m+1,r);
        if(right[ind] == -1) return sum(cL,m,left[ind],l,m);
        return sum(cL,m,left[ind],l,m)+sum(m+1,cR,right[ind],m+1,r);
    }
    T sum(int l, int r){
        return sum(0,n-1,0,l,r);
    }
};
// a squared (no overflow check).
ll square(ll a) {
    ll result = a;
    result *= a;
    return result;
}
// Squared Euclidean length of the vector p.
ll norm(const pl& p) {
    const ll dx = p.a;
    const ll dy = p.b;
    return square(dx) + square(dy);
}
// 2D cross product a x b; positive when b is counter-clockwise from a.
ll cross(const pl& a, const pl& b) {
    const ll lhs = a.a * b.b;
    const ll rhs = a.b * b.a;
    return lhs - rhs;
}
// Orientation of the turn p -> a -> b: cross product of (a - p) and (b - p).
ll cross(const pl& p, const pl& a, const pl& b) {
    const pl u{a.a - p.a, a.b - p.b};
    const pl w{b.a - p.a, b.b - p.b};
    return cross(u, w);
}
// Graham scan. Returns indices into v of the convex hull vertices in
// counter-clockwise order, starting at the lexicographically smallest point.
// Turns with cross <= 0 are popped, so collinear and duplicate points are not kept.
// Returns {} for empty input.
vi hullInd(const vpl& v) {
    // With no points min_element returns end(), and index 0 would be out of range.
    if (v.empty()) return {};
    int ind = min_element(v.begin(), v.end()) - v.begin();
    vi cand, hull{ind};

    for (int i = 0; i < (int)v.size(); i++)
        if (v[i] != v[ind])
            cand.push_back(i);
    // Sort by polar angle around the pivot; points collinear with it go nearest first.
    sort(cand.begin(), cand.end(), [&](int a, int b) {
        pl x = {v[a].a-v[ind].a,v[a].b-v[ind].b}, y = {v[b].a-v[ind].a,v[b].b-v[ind].b};
        ll t = cross(x, y);
        return t != 0 ? t > 0 : norm(x) < norm(y);
    });

    for (int c : cand) {
        // Keep only strict left turns.
        while (hull.size() > 1 && cross(v[hull[hull.size() - 2]], v[hull.back()], v[c]) <= 0)
            hull.pop_back();
        hull.push_back(c);
    }

    return hull;
}
typedef unsigned long long ull;
// Returns a*b mod M without 128-bit arithmetic: the quotient floor(a*b/M) is
// estimated in long double, the (wrapping) unsigned difference a*b - M*q is read
// as signed, and one conditional correction moves it into [0, M).
// NOTE(review): this trick (as published in KACTL) presumably assumes a, b < M,
// M up to about 7.2e18, and an x87 80-bit long double — confirm on the target.
ull modmul(ull a, ull b, ull M) {
	ll ret = a * b - M * ull(1.L / M * a * b);
	return ret + M * (ret < 0) - M * (ret >= (ll)M);
}
// b^e mod `mod` by square-and-multiply, using modmul for each product.
ull modpow(ull b, ull e, ull mod) {
	ull result = 1;
	while (e > 0) {
		if (e & 1) result = modmul(result, b, mod);
		b = modmul(b, b, mod);
		e >>= 1;
	}
	return result;
}
// Miller–Rabin primality test for 64-bit n, using the fixed 7-base set below
// (a base set known to be deterministic for all 64-bit inputs).
bool isPrime(ull n) {
	// n % 6 % 4 == 1 exactly when n % 6 is 1 or 5. Every other n is prime only if it is 2 or 3.
	if (n < 2 || n % 6 % 4 != 1) return (n | 1) == 3;
	// n is odd here, so n >> s == (n-1) >> s: write n - 1 = d * 2^s with d odd.
	ull A[] = {2, 325, 9375, 28178, 450775, 9780504, 1795265022},
	    s = __builtin_ctzll(n-1), d = n >> s;
	for (ull a : A) {
		ull p = modpow(a%n, d, n), i = s;
		// Square until p reaches 1 or n-1 or the s squarings run out. Bases
		// divisible by n are skipped (a % n == 0).
		while (p != 1 && p != n - 1 && a % n && i--)
			p = modmul(p, p, n);
		// Witness of compositeness: n-1 was never reached after at least one step.
		if (p != n-1 && i != s) return 0;
	}
	return 1;
}
// Pollard's rho: returns a divisor of n found via the sequence x -> x^2 + i (mod n).
// x moves one step and y two steps per iteration (cycle detection). The differences
// |x - y| are multiplied into prd, and gcd(prd, n) is only taken every 40 iterations
// to amortise its cost. NOTE(review): presumably n must be composite — factor()
// checks isPrime before calling this.
ull pollard(ull n) {
	ull x = 0, y = 0, t = 30, prd = 2, i = 1, q;
	auto f = [&](ull x) { return modmul(x, x, n) + i; };
	// t starts at 30, so the first gcd check happens after 10 iterations.
	while (t++ % 40 || __gcd(prd, n) == 1) {
		// Walkers collided: restart with a new additive constant i.
		if (x == y) x = ++i, y = f(x);
		// Skip products that would become 0 mod n, keeping prd non-zero.
		if ((q = modmul(prd, max(x,y) - min(x,y), n))) prd = q;
		x = f(x), y = f(f(y));
	}
	return __gcd(prd, n);
}
// Prime factorisation (with multiplicity, unsorted) of n; {} for n == 1.
vector<ull> factor(ull n) {
	if (n == 1) return {};
	if (isPrime(n)) return {n};
	// Split n with Pollard's rho and factor both parts recursively.
	const ull d = pollard(n);
	vector<ull> out = factor(d);
	const vector<ull> rest = factor(n / d);
	out.insert(out.end(), rest.begin(), rest.end());
	return out;
}
// SMAWK row-optima search. For each entry of `rows`, returns the chosen column
// from `cols`. f(r, j, k) returns true when column k should replace the current
// answer j for row r (callers pass j before k in column order). The chosen columns
// are assumed non-decreasing down the rows (total monotonicity); that assumption
// is what makes the reduce and interpolate steps below valid.
template<class F>
std::vector<int> smawck(F f, const std::vector<int> &rows, const std::vector<int> &cols) {
    std::vector<int> ans(rows.size(), -1);
    if((int) std::max(rows.size(), cols.size()) <= 2) {
        // Base case: brute force over at most 2 x 2 entries.
        for(int i = 0; i < (int) rows.size(); i++) {
            for(auto j : cols) {
                if(ans[i] == -1 || f(rows[i], ans[i], j)) {
                    ans[i] = j;
                }
            }
        }
    } else if(rows.size() < cols.size()) {
        // REDUCE: with more columns than rows, keep a stack of at most |rows|
        // candidate columns. The column at stack depth t is challenged on row t.
        std::vector<int> st;
        for(int j : cols) {
            while(1) {
                if(st.empty()) {
                    st.push_back(j);
                    break;
                } else if(f(rows[(int) st.size() - 1], st.back(), j)) {
                    st.pop_back();
                } else if(st.size() < rows.size()) {
                    st.push_back(j);
                    break;
                } else {
                    break;
                }
            }
        }
        ans = smawck(f, rows, st);
    } else {
        // INTERPOLATE: solve the odd-indexed rows recursively...
        std::vector<int> newRows;
        for(int i = 1; i < (int) rows.size(); i += 2) {
            newRows.push_back(rows[i]);
        }
        auto otherAns = smawck(f, newRows, cols);
        for(int i = 0; i < (int) newRows.size(); i++) {
            ans[2*i+1] = otherAns[i];
        }
        // ...then fill each even row by scanning only the columns between the
        // answers of its neighbouring odd rows (l..r, shared across iterations).
        for(int i = 0, l = 0, r = 0; i < (int) rows.size(); i += 2) {
            if(i+1 == (int) rows.size()) r = (int) cols.size();
            while(r < (int) cols.size() && cols[r] <= ans[i+1]) r++;
            ans[i] = cols[l++];
            for(; l < r; l++) {
                if(f(rows[i], ans[i], cols[l])) {
                    ans[i] = cols[l];
                }
            }
            l--;
        }
    }
    return ans;
}
// Convenience overload: SMAWK over the full index ranges rows 0..n-1, cols 0..m-1.
template<class F>
std::vector<int> smawck(F f, int n, int m) {
    std::vector<int> rows(n), cols(m);
    std::iota(rows.begin(), rows.end(), 0);
    std::iota(cols.begin(), cols.end(), 0);
    return smawck(f, rows, cols);
}

// (max, +) convolution of an arbitrary sequence with a "convex shape" kernel:
//   ans[i] = max over j of anyShape[j] + convexShape[i - j],   0 <= i < n + m - 1,
// where j ranges over 0 <= j < n with 0 <= i - j < m. Row optima are found with
// smawck, so the comparator below must describe a totally monotone matrix.
// NOTE(review): for the max form this presumably requires convexShape to be
// concave (non-increasing consecutive differences) — confirm with callers.
// An empty anyShape is treated as {0}; an empty convexShape returns anyShape unchanged.
template<class T>
std::vector<T> MaxConvolutionWithConvexShape(std::vector<T> anyShape, const std::vector<T> &convexShape) {
    if(convexShape.empty()) return anyShape;
    if(anyShape.empty()) anyShape.push_back(0);
    if((int) convexShape.size() == 1) {
        // A single-element kernel shifts every entry by convexShape[0]. Returning
        // anyShape unchanged would silently drop that offset.
        for(auto &x : anyShape) x = x + convexShape[0];
        return anyShape;
    }
    int n = (int) anyShape.size(), m = (int) convexShape.size();
    auto function = [&](int i, int j) {
        return anyShape[j] + convexShape[i-j];
    };
    // True when column k should replace column j for row i. An out-of-range k
    // (k > i) never wins; an out-of-range j (i - j >= m) always loses.
    auto comparator = [&](int i, int j, int k) {
        if(i < k) return false;
        if(i - j >= m) return true;
        return function(i, j) <= function(i, k);
    };
    const std::vector<int> best = smawck(comparator, n + m - 1, n);
    std::vector<T> ans(n + m - 1);
    for(int i = 0; i < n + m - 1; i++) {
        ans[i] = function(i, best[i]);
    }
    return ans;
}
// Maximum subarray sum of a[l..r]. The empty subarray counts, so the result is >= 0.
int kadane(vi& a, int l, int r){
    int best = 0, run = 0;
    for (int i = l; i <= r; ++i) {
        // Extend the running segment, or restart it if its sum went negative.
        run = (run > 0 ? run : 0) + a[i];
        if (run > best) best = run;
    }
    return best;
}
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    int n,a,b; cin >> n >> a >> b;
    a--,b--;
    vii g(n);
    for(int i = 1; i<n; i++){
        int x,y; cin >> x >> y;
        g[--x].pb(--y);
        g[y].pb(x);
    }
    vi depth(n);
    function<void(int,int)> dfs1;
    dfs1 = [&](int s, int p){
        if(p != -1) depth[s] = depth[p]+1;
        for(int i: g[s]) if(i != p)
            dfs1(i,s);
    };
    dfs1(a,-1);
    vi path{b};
    while(path[path.size()-1] != a)
        for(int i: g[path[path.size()-1]])
            if(depth[i]+1 == depth[path[path.size()-1]]){
                path.pb(i);
                break;
            }
    function<int(int,int,int)> dfs2;
    dfs2 = [&](int s, int p, int stop){
        vi next;
        for(int i: g[s]) if(i != p && i != stop)
            next.pb(dfs2(i,s,stop));
        sort(next.rbegin(),next.rend());
        int ans = 0;
        for(int i = 0; i<next.size(); i++) uMax(ans,next[i]+i);
        return ans+1;
    };
    int lo = 0, hi = path.size()-2;
    while(lo<hi-1){
        int m = (lo+hi)>>1;
        int a1 = dfs2(b,-1,path[m+1]);
        int a2 = dfs2(a,-1,path[m]);
        if(a1<a2) lo = m;
        else hi = m;
    }
    cout << min(max(dfs2(b,-1,path[lo+1]),dfs2(a,-1,path[lo])),max(dfs2(b,-1,path[hi+1]),dfs2(a,-1,path[hi])))-1 << "\n";
}
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...