Submission #1211315

# | Time | Username | Problem | Language | Result | Execution time | Memory
1211315 | (not shown) | loiiii12358 | 통행료 (IOI18_highway) | C++20 | 6 / 100 | 47 ms | 18600 KiB
#include "highway.h"
#include <bits/stdc++.h>
using namespace std;
// Textually turns every `int` written below into `long long` (presumably so
// cost arithmetic such as len*a cannot overflow). Types that must match the
// grader's signatures (find_pair's parameters, the vector passed to ask())
// are therefore spelled `int32_t` explicitly.
#define int long long

// n: vertex count; a, b: light/heavy edge weights (the code divides by b-a,
// so it assumes a < b); len: ask() with every edge light, divided by a --
// presumably the S-T distance in edges; s, e: endpoints passed to answer().
int n,a,b,len,s,e;
// used[v] == true: v is excluded from the part of the graph being searched.
vector<bool> used;
// para[i] == 1 marks edge i as heavy (weight b) in the next ask(para) query.
vector<int32_t> para;
// edge[v]: (neighbour, edge index) pairs for every edge incident to v.
vector<vector<pair<int,int>>> edge;

int run(int u,int dep){
    int cur,l,r;
    stack<int> in;
    vector<int> depth(n,0);
    vector<bool> visited(n,false);
    vector<pair<int,int>> pos;
    in.push(u);visited[u]=true;
    while(!in.empty()){
        cur=in.top();
        in.pop();
        for(auto [v,id]:edge[cur]){
            if(!visited[v]&&!used[v]){
                visited[v]=true;
                depth[v]=depth[cur]+1;
                if(depth[v]==dep){
                    pos.push_back({id,v});
                }else{
                    in.push(v);
                }
            }
        }
    }
    l=0;r=pos.size()-1;
    while(l!=r){
        int m=l+(r-l)/2;
        fill(para.begin(),para.end(),0);
        for(int i=0;i<m;i++){
            para[pos[i].first]=1;
        }
        if(ask(para)>len*a){
            r=m;
        }else{
            l=m+1;
        }
    }
    return pos[l].second;
}

// Roots the component of `u` (every vertex reachable from u without passing
// through a `used` vertex) at u and returns a length-n vector whose entry for
// each component vertex is the number of vertices in its rooted subtree.
// Vertices outside the component keep the initial value 1.
vector<int> subtrees(int u){
    vector<int> sz(n,1);
    vector<bool> mark(n,false);
    vector<int> order;      // vertices in the order they are popped (pre-order)
    stack<int> pending;

    mark[u]=true;
    pending.push(u);
    while(!pending.empty()){
        const int x=pending.top();
        pending.pop();
        order.push_back(x);
        for(const auto& [y,eid]:edge[x]){
            if(mark[y]||used[y]){
                continue;
            }
            mark[y]=true;
            pending.push(y);
        }
    }

    // Walk the pre-order backwards so every vertex is finished after all of
    // its descendants. A vertex's mark is cleared once it is finished, so when
    // its parent is processed the cleared neighbours are (in a tree) exactly
    // its finished children, while its own parent, still marked, is skipped.
    for(auto it=order.rbegin();it!=order.rend();++it){
        const int x=*it;
        for(const auto& [y,eid]:edge[x]){
            if(!mark[y]&&!used[y]){
                sz[x]+=sz[y];
            }
        }
        mark[x]=false;
    }
    return sz;
}

int centroid(int u){
    int cur=u;
    bool run=true;
    vector<int> subtree=subtrees(u);
    vector<bool> visited(n,false);
    visited[u]=true;
    while(run){
        run=false;
        for(auto [v,id]:edge[cur]){
            if(subtree[v]>=subtree[u]/2&&!visited[v]&&!used[v]){
                cur=v;
                visited[v]=true;
                run=true;
                break;
            }
        }
    }
    return cur;
}

void solve(int u){
    u=centroid(u);
    // cout << u << '\n';
    // for(auto i:used){
    //     cout << i << ' ';
    // }
    // cout << '\n';
    int cnt=0;
    vector<int> subtree=subtrees(u),l,r;
    vector<pair<int,int>> pos;
    for(auto [v,id]:edge[u]){
        if(!used[v]){
            pos.push_back({subtree[v],v});
        }
    }
    sort(pos.begin(),pos.end());
    for(int i=pos.size()-1;i>=0;i--){
        if(cnt+pos[i].first<=subtree[u]/2){
            cnt+=pos[i].first;
            l.push_back(pos[i].second);
        }else{
            r.push_back(pos[i].second);
            used[pos[i].second]=true;
        }
    }
    int cur;
    stack<int> in;
    vector<bool> visited(n,false);
    in.push(u);visited[u]=true;fill(para.begin(),para.end(),0);
    while(!in.empty()){
        cur=in.top();
        in.pop();
        for(auto [v,id]:edge[cur]){
            if(!visited[v]&&!used[v]){
                in.push(v);
                visited[v]=true;
                para[id]=1;
            }
        }
    }
    cnt=(ask(para)-len*a)/(b-a);
    if(cnt==len){
        solve(u);
        return;
    }else if(cnt==0){
        for(auto i:l){
            used[i]=true;
        }
        for(auto i:r){
            used[i]=false;
        }
        solve(u);
        return;
    }
    s=run(u,cnt);
    for(auto i:l){
        used[i]=true;
    }
    for(auto i:r){
        used[i]=false;
    }
    e=run(u,len-cnt);
}

// Grader entry point. Builds the adjacency list, measures the S-T distance
// with an all-light query, finds both endpoints and reports them via answer().
void find_pair(int32_t N, std::vector<int32_t> U, std::vector<int32_t> V, int32_t A, int32_t B) {
  const int M=U.size();
  edge.resize(N);used.resize(N,false);para.resize(M,0);
  for(int i=0;i<M;i++){
    edge[U[i]].push_back({V[i],i});
    edge[V[i]].push_back({U[i],i});
  }
  n=N;a=A;b=B;len=ask(para)/A;
  if(len>1){
    solve(0);
  }else{
    // S and T are adjacent. Binary-search the edge index t: making edges
    // [0..m] heavy pushes the cost above len*a iff t <= m.
    int l=0,r=M-1;
    while(l<r){
        int m=l+(r-l)/2;
        fill(para.begin(),para.end(),0);
        fill(para.begin(),para.begin()+m+1,1);
        if(ask(para)>len*a){
            r=m;
        }else{
            l=m+1;
        }
    }
    s=U[l];
    e=V[l];
  }
  answer(s,e);
}
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...