// Submission #973266
//
// # | Time | Username | Problem | Language | Result | Execution time | Memory
// 973266 | (time not captured) | efedmrlr | Mousetrap (CEOI17_mousetrap) | C++17 | 45 / 100 | 778 ms | 124804 KiB
// #pragma GCC optimize("O3,Ofast,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include <bits/stdc++.h>

using namespace std;


// Shorthand macros used by the rest of this file.

// From here on every `int` token expands to `long long int`, including inside
// template arguments such as vector<int>. The entry point is spelled
// `signed main` so that its return type is not rewritten by this macro.
#define int long long int
// MP(a, b) -> make_pair(a, b).
#define MP make_pair
// v.pb(x) -> v.push_back(x).
#define pb push_back
// REP(i, n): loop i = 0 .. n-1. `i` is declared by the macro (with the
// redefined `int` type) and `n` is re-evaluated on every iteration.
#define REP(i,n) for(int i = 0; (i) < (n); (i)++)
// all(x): forward iterator pair of x. The argument is not parenthesised, so
// pass a plain identifier.
#define all(x) x.begin(), x.end()
// rall(x): reverse iterator pair of x (e.g. sort(rall(v)) sorts descending).
#define rall(x) x.rbegin(), x.rend()


// Speeds up iostream usage: stops synchronising the C++ streams with C stdio
// and unties cin from cout so reads no longer force a flush of pending output.
// Must run before any input or output is performed.
void fastio() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
}


// Generic template constants. Only N is referenced by the code visible in
// this file; the others are kept for compatibility.
constexpr double EPS = 1e-5;
constexpr int INF = 1'000'000'500;
constexpr int N = 1'000'005;  // capacity of the per-node tables below
constexpr int ALPH = 26;
constexpr int LGN = 25;
constexpr int MOD = 1e9+7;

// Input parameters, read in solve(): n is the node count (n - 1 edges are
// read), t is the node dfs2() searches for, m is the node the search starts at.
int n;
int t;
int m;

// dp[v]: per-node value filled in by dfs(); zero until computed.
vector<int> dp(N);
// adj[v]: neighbours of node v (undirected adjacency lists).
vector<vector<int>> adj(N);
// Records {dp of a subtree root hanging off the m..t path, path index of the
// node it hangs from}; appended by dfs2().
vector<array<int, 2>> vs;
void dfs(int node, int par) {
    for(auto c : adj[node]) {
        if(c == par) continue;
        dfs(c, node);
    }
    array<int, 2> mx = {0, 0};
    for(auto c : adj[node]) {
        if(c == par) continue;
        if(dp[c] > mx[0]) {
            swap(mx[0], mx[1]);
            mx[0] = dp[c];
        }
        else if(dp[c] > mx[1]) {
            mx[1] = dp[c];
        }
    }
    dp[node] = mx[1] + (int)adj[node].size() - 1; 

}

// Finds the path from `node` to the global target `t` (never stepping back
// into `par`) and, for every path node except `t`, evaluates each subtree that
// hangs off the path with dfs() and appends {dp[root of that subtree], index
// of the path node} to `vs`. The index of `node` is `dist`, the next node
// towards `t` has index dist + 1, and so on.
//
// Entries are appended starting with the path node adjacent to `t` and ending
// with `node`, in adjacency-list order within a path node.
//
// Implemented with an explicit stack: the original recursion was as deep as
// the path itself, which can exhaust the call stack on long chains.
//
// @param node  start of the search
// @param par   neighbour of `node` that must not be entered (0 for none;
//              node ids are assumed to be >= 1, as in solve())
// @param dist  path index assigned to `node`
// @return `node` if `t` is `node` itself or reachable from it, otherwise 0
int dfs2(int node, int par, int dist = 0) {
    if (node == t) {
        return node;
    }

    // from[v]: the neighbour through which v was discovered.
    vector<int> from(adj.size(), 0);
    from[node] = par;
    vector<int> pending;
    pending.push_back(node);
    bool found = false;
    while (!pending.empty() && !found) {
        const int cur = pending.back();
        pending.pop_back();
        for (int nxt : adj[cur]) {
            if (nxt == from[cur]) continue;
            from[nxt] = cur;
            if (nxt == t) {
                found = true;
                break;
            }
            pending.push_back(nxt);
        }
    }
    if (!found) {
        return 0;
    }

    // path[0] == t, path.back() == node.
    vector<int> path;
    for (int v = t; v != node; v = from[v]) {
        path.push_back(v);
    }
    path.push_back(node);

    const int last = static_cast<int>(path.size()) - 1;
    for (int j = 1; j <= last; ++j) {
        const int cur = path[j];
        const int toward_t = path[j - 1];
        const int index_on_path = dist + (last - j);
        for (int c : adj[cur]) {
            // Skip both neighbours that lie on the path (or `par` for `node`).
            if (c == from[cur] || c == toward_t) continue;
            dfs(c, cur);
            vs.push_back({dp[c], index_on_path});
        }
    }
    return node;
}

inline void solve() {
    cin>>n>>t>>m;
    REP(i, n - 1) {
        int u, v;
        cin >> u >> v;
        adj[u].pb(v);
        adj[v].pb(u);
    }
    dfs2(m, 0);
    sort(rall(vs));
    int ans = (int)vs.size();
    set<int> slot;
    for(int i = 0; i < n; i++) {
        slot.insert(i);
    }
    for(auto &c : vs) {
        auto itr = slot.upper_bound(c[1]);
        if(itr == slot.begin()) {
            ans += c[0];
            break;
        }
        itr--;
        slot.erase(itr);
    }
    cout << ans << "\n";
}
 
// Entry point: configures fast I/O, then runs solve() once.
// Declared `signed` because the `int` token is redefined by a macro in this file.
signed main() {
    fastio();
    // Exactly one test case; reading the count from input stays disabled:
    //cin>>test;
    for (int remaining = 1; remaining > 0; --remaining) {
        solve();
    }
    return 0;
}
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...