Submission #1067795

| # | Submitted at | User | Problem | Language | Result | Time | Memory |
|---|---|---|---|---|---|---|---|
| 1067795 | 2024-08-21T03:22:16Z | thinknoexit | Torrent (COI16_torrent) | C++17 | 100 / 100 | 826 ms | 36756 KB |
#include<bits/stdc++.h>
using namespace std;
using ll = long long;
// Array capacity; vertex ids (1..n) are used directly as indices, so n must stay below N.
constexpr int N = 300300;

// adj2: the undirected input tree as read by main().
vector<int> adj2[N];
// adj: for each vertex, its children inside the subtrees hanging off the X-Y path
// (built by dfs1, which never descends into another path vertex).
vector<int> adj[N];

// Vertices of the X-Y path (main() reverses it so that nodes.front() == X).
vector<int> nodes;

bool vis[N];  // visited flag used by find_path
bool in[N];   // true for vertices that lie on the X-Y path

// The two start vertices read from input.
int X;
int Y;

// dp[v]: minimum number of minutes a vertex that already holds the file needs
// to pass it to its whole subtree (when node v is active).
int dp[N];
// Prefix/suffix buffers; not referenced anywhere in this file.
int pref[N];
int suff[N];
bool find_path(int v) {
    vis[v] = 1;
    if (v == Y) {
        in[v] = 1;
        nodes.push_back(v);
        return true;
    }
    for (auto& x : adj2[v]) {
        if (vis[x]) continue;
        if (find_path(x)) {
            in[v] = 1;
            nodes.push_back(v);
            return true;
        }
    }
    return false;
}
// Roots the part of the input tree (adj2) reachable from v without entering
// a path vertex (in[] set) or going back through p, and records each vertex's
// children in adj[] in adj2 order.
// Iterative, because a hanging subtree can be a chain of ~n vertices and the
// recursive form would need that many call frames.
void dfs1(int v, int p = -1) {
    vector<pair<int, int>> stk;  // (vertex, its parent)
    stk.emplace_back(v, p);
    while (!stk.empty()) {
        const auto [u, parent] = stk.back();
        stk.pop_back();
        for (int x : adj2[u]) {
            if (in[x] || x == parent) continue;
            adj[u].push_back(x);
            stk.emplace_back(x, u);
        }
    }
}
// Fills dp[] bottom-up for the rooted tree below v described by adj[]:
// children of u are sorted by dp descending, and the child served k-th
// (1-based) contributes dp[child] + k, so dp[u] is the maximum of those terms
// (0 for a leaf).  Leaves every adj[u] sorted in that order.
// Iterative: the subtree depth can reach ~n, too deep for recursion.
void dfs(int v) {
    // BFS order lists every child after its parent, so walking it backwards
    // finalises all children before the vertex that needs them.
    vector<int> order{v};
    for (size_t i = 0; i < order.size(); i++) {
        const int u = order[i];
        for (int x : adj[u]) order.push_back(x);
    }
    for (auto it = order.rbegin(); it != order.rend(); ++it) {
        const int u = *it;
        vector<int>& ch = adj[u];
        sort(ch.begin(), ch.end(), [](int a, int b) { return dp[a] > dp[b]; });
        int best = 0;
        for (int k = 0; k < static_cast<int>(ch.size()); k++)
            best = max(best, dp[ch[k]] + k + 1);
        dp[u] = best;
    }
}
// Minutes for v to pass the file to everything below it: the hanging children
// in adj[v] plus, when extra >= 0, the path neighbour `extra`.  Same recurrence
// as dfs(): children sorted by dp descending, the k-th (1-based) costs dp + k.
static int node_time(int v, int extra) {
    static vector<int> children;  // reused buffer: avoids one allocation per call
    children.assign(adj[v].begin(), adj[v].end());
    if (extra >= 0) children.push_back(extra);
    sort(children.begin(), children.end(),
        [](int a, int b) { return dp[a] > dp[b]; });
    int best = 0;
    for (int k = 0; k < static_cast<int>(children.size()); k++)
        best = max(best, dp[children[k]] + k + 1);
    return best;
}

// Minutes X needs to cover path vertices nodes[0..last] and their hanging
// subtrees.  Overwrites dp[] of those path vertices only.
static int time_from_x(int last) {
    dp[nodes[last]] = node_time(nodes[last], -1);
    for (int j = last - 1; j >= 0; j--)
        dp[nodes[j]] = node_time(nodes[j], nodes[j + 1]);
    return dp[nodes[0]];
}

// Minutes Y needs to cover path vertices nodes[first..] and their hanging
// subtrees.  Overwrites dp[] of those path vertices only.
static int time_from_y(int first) {
    const int last = static_cast<int>(nodes.size()) - 1;
    dp[nodes[first]] = node_time(nodes[first], -1);
    for (int j = first + 1; j <= last; j++)
        dp[nodes[j]] = node_time(nodes[j], nodes[j - 1]);
    return dp[nodes[last]];
}

int main() {
    cin.tie(nullptr)->sync_with_stdio(false);
    int n;
    cin >> n >> X >> Y;
    for (int i = 1; i < n; i++) {
        int u, v;
        cin >> u >> v;
        adj2[u].push_back(v);
        adj2[v].push_back(u);
    }
    // nodes = X, ..., Y along the tree path.
    find_path(X);
    reverse(nodes.begin(), nodes.end());
    // dp[] for every hanging subtree; path vertices count only off-path children.
    for (int i = 1; i <= n; i++) {
        if (in[i]) dfs1(i), dfs(i);
    }

    const int sz = static_cast<int>(nodes.size());
    if (sz == 1) {  // X == Y: a single source fills the whole tree
        cout << dp[X] << '\n';
        return 0;
    }

    // Cutting path edge (nodes[i], nodes[i+1]) gives X the side nodes[0..i]
    // and Y the side nodes[i+1..].  Enlarging a side never makes it faster, so
    // time_from_x(i) is nondecreasing in i and time_from_y(i+1) nonincreasing.
    // Find the first cut where X's side is at least as slow as Y's; the optimum
    // max(...) is attained at that cut or at the one just before it.
    int lo = 0, hi = sz - 2;
    while (lo < hi) {
        const int mid = (lo + hi) / 2;
        if (time_from_x(mid) >= time_from_y(mid + 1)) hi = mid;
        else lo = mid + 1;
    }
    int ans = max(time_from_x(lo), time_from_y(lo + 1));
    if (lo > 0) ans = min(ans, max(time_from_x(lo - 1), time_from_y(lo)));
    cout << ans << '\n';
    return 0;
}
| # | Result | Time | Memory | Grader output |
|---|---|---|---|---|
| 1 | Correct | 7 ms | 14424 KB | Output is correct |
| 2 | Correct | 6 ms | 14428 KB | Output is correct |
| 3 | Correct | 7 ms | 14428 KB | Output is correct |
| # | Result | Time | Memory | Grader output |
|---|---|---|---|---|
| 1 | Correct | 307 ms | 30828 KB | Output is correct |
| 2 | Correct | 424 ms | 35412 KB | Output is correct |
| 3 | Correct | 713 ms | 35724 KB | Output is correct |
| 4 | Correct | 495 ms | 36152 KB | Output is correct |
| 5 | Correct | 296 ms | 35152 KB | Output is correct |
| 6 | Correct | 521 ms | 35412 KB | Output is correct |
| 7 | Correct | 826 ms | 36756 KB | Output is correct |