답안 #758375

| # | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리 |
|---|---|---|---|---|---|---|---|
| 758375 | 2023-06-14T14:21:01Z | NintsiChkhaidze | Torrent (COI16_torrent) | C++17 | 100 / 100 | 163 ms | 28064 KB |
#include <bits/stdc++.h>
#define pb push_back   // legacy shorthand: x.pb(y) is x.push_back(y)
using namespace std;

// N bounds the vertex ids (1..n with n < N); inf marks "no value yet" in val[] and t[].
constexpr int N = 300005;
constexpr int inf = 1000000000;

vector<int> pt;       // vertices of the a..b path in order, filled by findpath
vector<int> v[N];     // adjacency lists: v[x] holds the neighbours of x
bool f = false;       // set by findpath once the target vertex has been reached
bool fix[N];          // fix[x] is true iff x lies on the a..b path (off-path dfs stops there)
int dp[N];            // dp[x] = dfs(x, x) for every path vertex x
int val[N], t[N];     // scratch values for path vertices, reset on every feasibility check in main

void findpath(int x,int par,int y){
    pt.pb(x);
    if (x == y) {f=1; return;}
    for (int to: v[x]){
        if (to == par) continue;
        findpath(to,x,y);
        if (f) return;
    }
    pt.pop_back();
}

int dfs(int x,int par){
    vector <int> k;
    for (int to: v[x]){
        if (to == par || fix[to]) continue;
        k.pb(dfs(to,x));
    }

    sort(k.begin(),k.end());
    int ans = 0;
    for (int i=(int)k.size()-1;i>=0;i--){
        ans = max(ans,k[i] + (int)k.size() - i);
    }
    return ans;
}
signed main (){
    ios_base::sync_with_stdio(0),cin.tie(NULL),cout.tie(NULL);
    
    int n,a,b;
    cin>>n>>a>>b;

    for (int i = 1; i < n; i++){
        int x,y;
        cin>>x>>y;
        v[x].pb(y);
        v[y].pb(x);
    }

    findpath(a,a,b);
    int m = pt.size();

    for (int x: pt) fix[x] = 1;
    for (int x: pt) dp[x] = dfs(x,x);
    
    int mid,l = max(dp[a],dp[b]),r = n,ans = n,x,pr,i;
    bool check;
    while (l <= r){
        mid = (l + r)>>1;
        for (i = 1;i <= n; i++)
            val[i] = t[i] = inf;
        
        check=1;
        val[a] = dp[a];
        val[b] = dp[b];
        t[a] = t[b] = 0;

        for (i = 1; i < m; i++){
            x = pt[i]; pr = pt[i - 1];
            if (val[pr] + 1 <= mid){
                t[x] = min(t[x],t[pr] + 1);
            }else{
                int p = t[pr] + (int)v[pr].size();
                if (pr == a || pr == b) p--;
                t[x] = min(t[x],p);
            }
            val[x] = min(val[x],t[x] + dp[x]);
        }

        for (i = m - 1; i >= 0; i--){
            x = pt[i]; pr = pt[i + 1];
            if (val[pr] + 1 <= mid){
                t[x] = min(t[x],t[pr] + 1);
            }else{
                int p = t[pr] + (int)v[pr].size();
                if (pr == a || pr == b) p--;
                t[x] = min(t[x],p);
            }
            val[x] = min(val[x],t[x] + dp[x]);
        }

        for (i = 1; i <= n; i++){
            if (fix[i] && val[i] > mid) check=0;
        }
        if (check){
            ans= mid;
            r=mid-1;
        }else{
            l=mid+1;
        }
    }

    cout<<ans;
}
| # | 결과 | 실행 시간 | 메모리 | Grader output |
|---|---|---|---|---|
| 1 | Correct | 4 ms | 7380 KB | Output is correct |
| 2 | Correct | 4 ms | 7372 KB | Output is correct |
| 3 | Correct | 4 ms | 7376 KB | Output is correct |
| # | 결과 | 실행 시간 | 메모리 | Grader output |
|---|---|---|---|---|
| 1 | Correct | 132 ms | 25684 KB | Output is correct |
| 2 | Correct | 163 ms | 26600 KB | Output is correct |
| 3 | Correct | 121 ms | 27920 KB | Output is correct |
| 4 | Correct | 113 ms | 27336 KB | Output is correct |
| 5 | Correct | 117 ms | 25356 KB | Output is correct |
| 6 | Correct | 105 ms | 25816 KB | Output is correct |
| 7 | Correct | 124 ms | 28064 KB | Output is correct |