Submission #1364471

# | Time | Username | Problem | Language | Result | Execution time | Memory
1364471 | | mariza | Highway Tolls (IOI18_highway) | C++20 | 51 / 100 | 289 ms | 327680 KiB
#include "highway.h"
#include <bits/stdc++.h>
using namespace std;

using ll = long long;        // 64-bit signed integer used for vertices, depths and costs
constexpr ll N4 = 90000;     // length of every per-vertex array below (9e4 vertices)

// Midpoint of the current binary-search window. This is a textual macro: it
// reads whichever variables named `l` and `r` are in scope where it expands.
#define MID ((l + r) / 2)

// Adjacency lists: t[x] holds the neighbours of vertex x (read by dfs and mark).
std::vector<ll> t[N4];

ll d[N4];  // d[x]: depth of vertex x relative to the walk's start, filled in by dfs

// Propagates depths through the tree held in `t`, starting at `curr`, which was
// reached from `prev` (pass prev == curr for the root). Every vertex w below
// curr gets d[w] = d[curr] + (number of edges from curr to w). d[curr] itself
// is only read, so the caller must set it beforehand.
// Uses an explicit stack instead of recursion. On a path-shaped tree the
// recursion depth would reach N - 1 (~9e4 frames), which can overflow a
// default-sized call stack.
// Requires `t` to be acyclic. Only the immediate parent is skipped, so a cycle
// would be walked forever.
void dfs(ll curr, ll prev){
    std::vector<std::pair<ll, ll>> todo;  // (vertex, parent it was reached from)
    todo.emplace_back(curr, prev);
    while (!todo.empty()) {
        const auto [node, parent] = todo.back();
        todo.pop_back();
        for (const ll nxt : t[node]) {
            if (nxt == parent) continue;
            d[nxt] = d[node] + 1;  // node's depth is final before it is expanded
            todo.emplace_back(nxt, node);
        }
    }
}

ll y[N4]={};  // y[x] != 0: vertex x was flagged by mark (flags are never cleared)

// Flags, in `y`, every vertex in the subtree of `curr` (the tree in `t`, with
// `curr` reached from `prev`) whose own subtree contains `x` or an
// already-flagged vertex. Called from the root, this flags the root-to-x path.
// Returns whether `curr` ends up flagged. `y[prev]` is never touched.
// Iterative to avoid O(N)-deep recursion on path-shaped trees. Requires `t` to
// be acyclic.
bool mark(ll curr, ll prev, ll x){
    // Pre-order walk, recording every vertex with the parent it was reached from.
    std::vector<std::pair<ll, ll>> visit_order;
    std::vector<std::pair<ll, ll>> pending{{curr, prev}};
    while (!pending.empty()) {
        const auto [node, parent] = pending.back();
        pending.pop_back();
        visit_order.emplace_back(node, parent);
        for (const ll nxt : t[node]) {
            if (nxt != parent) pending.emplace_back(nxt, node);
        }
    }

    // A vertex's descendants are recorded after it, so walking the record
    // backwards finishes every child before its parent. Flags are therefore
    // pushed one level up at a time.
    for (auto it = visit_order.rbegin(); it != visit_order.rend(); ++it) {
        const auto [node, parent] = *it;
        if (node == x) y[node] = 1;
        if (y[node] && node != curr) y[parent] = 1;
    }

    return y[curr];
}

void find_pair(int n, vector<int> u, vector<int> v, int a, int b){
    for(ll i=0; i<n-1; i++){
        t[u[i]].push_back(v[i]);
        t[v[i]].push_back(u[i]);
    }

    d[0]=0;
    dfs(0,0);

    ll maxd=0;
    for(ll i=0; i<n; i++){
        maxd=max(maxd,d[i]);
    }

    ll dist=ask(vector<int>(n-1,0));

    ll l=0, r=maxd, lca;
    while(l<=r){
        vector<int> x(n-1,0);
        for(ll i=0; i<n-1; i++){
            if(max(d[u[i]],d[v[i]])<=MID) x[i]=1;
        }

        if(ask(x)==dist){
            lca=MID;
            l=MID+1;
        }
        else r=MID-1;
    }

    l=0; r=maxd;
    ll d2;
    while(l<=r){
        vector<int> x(n-1,0);
        for(ll i=0; i<n-1; i++){
            if(min(d[u[i]],d[v[i]])>=MID) x[i]=1;
        }

        if(ask(x)==dist){
            d2=MID;
            r=MID-1;
        }
        else l=MID+1;
    }

    ll d1=lca+dist/a-(d2-lca);

    // cout<<d1<<" "<<d2<<endl;

    vector<ll> e1, e2;
    for(ll i=0; i<n-1; i++){
        if(max(d[u[i]],d[v[i]])==d1) e1.push_back(i);
        if(max(d[u[i]],d[v[i]])==d2) e2.push_back(i);
    }

    l=0; r=e2.size()-1;
    ll idx2;
    while(l<=r){
        vector<int> x(n-1,0);
        for(ll i=l; i<=MID; i++){
            x[e2[i]]=1;
        }

        if(ask(x)>dist){
            idx2=MID;
            r=MID-1;
        }
        else l=MID+1;
    }
    ll ans2;
    if(d[u[e2[idx2]]]==d2) ans2=u[e2[idx2]];
    else ans2=v[e2[idx2]];

    mark(0,0,ans2);
    ll z=-2;
    for(ll i=0; i<n; i++){
        // cout<<i<<" "<<y[i]<<" "<<d[i]<<endl;
        if(y[i] && d[i]==d1) z=i;
    }

    l=0; r=e1.size()-1;
    ll idx1=-1;
    while(l<=r){
        vector<int> x(n-1,0);
        for(ll i=l; i<=MID; i++){
            if(u[e1[i]]!=z && v[e1[i]]!=z) x[e1[i]]=1;
        }

        if(ask(x)>dist){
            idx1=MID;
            r=MID-1;
        }
        else l=MID+1;
    }
    ll ans1;
    if(idx1==-1) ans1=z;
    else if(d[u[e1[idx1]]]==d1) ans1=u[e1[idx1]];
    else ans1=v[e1[idx1]];

    answer(ans1,ans2);
    // cout<<ans1<<" "<<ans2<<endl;
}
# | Result | Execution time | Memory | Grader output
Fetching results...
# | Result | Execution time | Memory | Grader output
Fetching results...
# | Result | Execution time | Memory | Grader output
Fetching results...
# | Result | Execution time | Memory | Grader output
Fetching results...
# | Result | Execution time | Memory | Grader output
Fetching results...
# | Result | Execution time | Memory | Grader output
Fetching results...