Submission #1240082

# | Time | Username | Problem | Language | Result | Execution time | Memory
1240082 | vietbachleonkroos2326 | Torrent (COI16_torrent) | C++20
69 / 100
610 ms42664 KiB
#include <bits/stdc++.h>
using namespace std;

// Capacity bound for the per-node arrays below; they are indexed directly by
// node id (solve() uses ids 1..n, and 0 is used as the "no parent" marker).
const int maxn = 1e6;
// n: number of nodes read by solve(); r1, r2: the two root nodes read by solve().
// NOTE(review): presumably r1/r2 are the two computers that initially hold the
// file in the Torrent task — confirm against the problem statement.
int n, r1, r2;
// ke[u]: adjacency list of node u (each input edge is stored in both directions).
vector<int> ke[maxn + 5];
// Nodes on the path between the roots, ordered from r1 to r2 (built by findpath()).
vector<int> path;
// trace_par[u]: parent of u in the DFS rooted at r1 (0 for r1 itself); set by dfs_trace().
int trace_par[maxn + 5];
// dp[u]: value computed by dfs_compute() for the subtree below u: the maximum,
// over children taken in descending dp order with 1-based rank i, of dp[child] + i.
// Zeroed by calc_split() before every evaluation.
int dp[maxn + 5];

// Depth-first walk that records, for every node reachable from u without
// stepping back to p, the neighbour it was reached from in trace_par.
//   u - node being visited
//   p - node we arrived from (0 when u is the root of the walk)
void dfs_trace(int u, int p) {
    trace_par[u] = p;
    for (auto it = ke[u].begin(); it != ke[u].end(); ++it) {
        const int next = *it;
        if (next != p) {
            dfs_trace(next, u);
        }
    }
}

// Rebuilds `path` as the node sequence from r1 to r2.
// The tree is first rooted at r1 (filling trace_par); then the parent chain is
// followed upward from r2 and the collected nodes are flipped into r1 -> r2 order.
void findpath() {
    path.clear();
    dfs_trace(r1, 0);
    // Parent 0 marks the root, so the walk ends at r1 or when the chain runs out.
    for (int node = r2; node != 0; node = trace_par[node]) {
        path.push_back(node);
        if (node == r1) {
            break;
        }
    }
    std::reverse(path.begin(), path.end());
}

// Fills dp for the subtree below u, treating the edge blocked1-blocked2 as removed.
//   u        - node being processed
//   p        - node we arrived from (0 for the root of this traversal)
//   blocked1 - one endpoint of the removed edge
//   blocked2 - the other endpoint of the removed edge
// Children are ranked by their dp value in descending order; the child with
// 1-based rank i contributes dp[child] + i, and dp[u] becomes the largest such
// contribution. dp[u] is only ever raised through max(), so the caller is
// expected to have zeroed dp beforehand.
void dfs_compute(int u, int p, int blocked1, int blocked2) {
    const bool at_first_end = (u == blocked1);
    const bool at_second_end = (u == blocked2);

    std::vector<int> child_cost;
    for (auto it = ke[u].begin(); it != ke[u].end(); ++it) {
        const int v = *it;
        const bool crosses_cut =
            (at_first_end && v == blocked2) || (at_second_end && v == blocked1);
        if (v == p || crosses_cut) {
            continue;
        }
        dfs_compute(v, u, blocked1, blocked2);
        child_cost.push_back(dp[v]);
    }

    // Sorting through reverse iterators yields descending order.
    std::sort(child_cost.rbegin(), child_cost.rend());

    int rank = 0;
    for (int cost : child_cost) {
        ++rank;
        dp[u] = std::max(dp[u], cost + rank);
    }
}

// Evaluates the split obtained by removing the path edge path[mid]-path[mid+1].
// Returns max(dp[r1], dp[r2]) after recomputing dp independently from each root,
// or INT_MAX when `mid` does not name an edge of the path.
// Side effect: dp is reset and refilled, so dp[r1] and dp[r2] hold the two
// per-root values for this split when the function returns.
int calc_split(int mid) {
    const int last_edge = static_cast<int>(path.size()) - 2;
    if (mid < 0 || mid > last_edge) {
        return INT_MAX;
    }

    std::fill(dp, dp + n + 1, 0);

    const int cut_a = path[mid];
    const int cut_b = path[mid + 1];
    dfs_compute(r1, 0, cut_a, cut_b);
    dfs_compute(r2, 0, cut_a, cut_b);

    return std::max(dp[r1], dp[r2]);
}

void solve() {
    cin >> n >> r1 >> r2;
    for (int i = 1; i <= n; i++) ke[i].clear();
    for (int i = 1; i < n; i++) {
        int u, v;
        cin >> u >> v;
        ke[u].push_back(v);
        ke[v].push_back(u);
    }
    findpath();

    int L = 0, R = path.size() - 2;
    int ans = INT_MAX;
    
    
    while (R - L > 4) {
        int m1 = L + (R - L)/3;
        int m2 = R - (R - L)/3;
        int f1 = calc_split(m1);
        int f2 = calc_split(m2);
        ans = min(ans, min(f1, f2));
        
        if (f1 < f2) {
            R = m2;
        } else {
            L = m1;
        }
    }
    
    
    for (int i = L; i <= R; i++) {
        ans = min(ans, calc_split(i));
    }
    
    cout << ans;
}

// Program entry point: switches the C++ streams to unsynchronised, untied mode
// for faster I/O, then runs the single test case.
int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    solve();
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...