Submission #805957

# | Time | Username | Problem | Language | Result | Execution time | Memory
805957 | | fatemetmhr | Highway Tolls (IOI18_highway) | C++17 | 100 / 100 | 216 ms | 33748 KiB
//  ~ Be Name Khoda ~  //
 
#include "highway.h"
#include <bits/stdc++.h>
//#pragma GCC optimize ("O3")
//#pragma GCC target("avx2")
//#pragma GCC optimize("unroll-loops,Ofast")
 
using namespace std;
 
typedef long long ll;
 
// Short-hand macros. pb and all(x) are used below; mp, fi and se are not used
// in this file (template leftovers).
#define pb       push_back
#define mp       make_pair
#define all(x)   x.begin(), x.end()
#define fi       first
#define se       second
 
// Size constants. Only maxn5 is referenced in this file; maxn, maxnt, maxn3,
// mod and inf appear to be unused template leftovers.
const int maxn  =  1e6   + 10;
const int maxn5 =  2e5   + 10;
const int maxnt =  1.2e6 + 10;
const int maxn3 =  1e3   + 10;
const int mod   =  1e9   + 7;
const ll  inf   =  1e18;
 
// tok        : toll assignment passed to ask(), one entry per edge (0 or 1).
// adj[x]     : ids of the edges incident to vertex x.
// u, v       : u[i], v[i] are the endpoints of edge i (copied from the grader).
// ver[t]     : vertices of side t's tree in DFS post-order (filled by dfs()).
// ed[t]      : ed[t][k] is the id of the edge leading into ver[t][k].
// jda[x]     : ids of tree edges going from x to its children (built in find_pair).
vector <int> tok, adj[maxn5], v, u, ver[2], ed[2], jda[maxn5];
// len : ask(all zeros) / a -- presumably the S-T distance in edges
//       (TODO confirm against the task statement's cost model).
// h[x]: BFS depth of x from the root of the latest bfs() call, -1 if unreached.
ll len, h[maxn5];
// m      : number of edges.
// q      : BFS queue storage.
// par[x] : BFS parent of x in the latest bfs() call (-1 for the root).
// kh     : copy of h from the first BFS (rooted at u[gooded]).
// gooded : edge index found by the first binary search in find_pair.
int m, q[maxn5], par[maxn5], kh[maxn5], gooded;
// parid[x] : id of the edge (par[x], x) from the latest bfs() call.
// kparid, kpar : copies of parid / par from the first BFS.
int kparid[maxn5], parid[maxn5], kpar[maxn5];
// pas[t][i] : memoised result of get(i, t); -1 means "not asked yet".
ll pas[2][maxn5];
// mark[x] : visited flag used by dfs(). good[] is not referenced in this file.
bool mark[maxn5], good[maxn5];
 
// Breadth-first search from `rt` over the graph given by adj / u / v.
// Postconditions for every vertex x reached from rt:
//   h[x]     = number of edges on a shortest path rt -> x (h[rt] = 0),
//   par[x]   = predecessor of x in the BFS tree (par[rt] = -1),
//   parid[x] = id of the edge (par[x], x).
// Every entry of h is first reset to -1, so unreached vertices keep h == -1;
// par/parid of unreached vertices (and parid[rt]) are left untouched.
void bfs(int rt){
    std::fill(std::begin(h), std::end(h), -1LL);
    par[rt] = -1;
    h[rt] = 0;

    int head = 0, tail = 0;
    q[tail++] = rt;
    while(head != tail){
        const int cur = q[head++];
        for(const int e : adj[cur]){
            // The neighbour across edge e is whichever endpoint is not `cur`.
            const int nxt = (u[e] == cur) ? v[e] : u[e];
            if(h[nxt] == -1){
                h[nxt] = h[cur] + 1;
                par[nxt] = cur;
                parid[nxt] = e;
                q[tail++] = nxt;
            }
        }
    }
}
 
void dfs(int v, int ty){
    mark[v] = true;
    //cout << v << ' ' << par[v] << endl;
    for(auto id : jda[v]){
        int u = (::u[id]) == v ? (::v[id]) : (::u[id]);
        if(mark[u])
            continue;
        //cout << "in " << v << ' ' << u << ' ' << ty << endl;
        dfs(u, ty);
        ver[ty].pb(u);
        ed[ty].pb(id);
    }
}
 
// Memoised query used by the per-side binary search.
// Builds the toll vector in which every edge is 1 except:
//   - edge `gooded`,
//   - every edge listed in ed[ty ^ 1] (the other side's tree),
//   - edges ed[ty][i] with i > id,
// which are all set to 0. It then calls ask() and caches the result in
// pas[ty][id]; later calls with the same (id, ty) return the cached value
// without querying again.
ll get(int id, int ty){
    ll &memo = pas[ty][id];
    if(memo != -1)
        return memo;

    std::fill(tok.begin(), tok.end(), 1);
    tok[gooded] = 0;
    for(const int e : ed[ty ^ 1])
        tok[e] = 0;
    const int cnt = int(ed[ty].size());
    for(int i = id + 1; i < cnt; ++i)
        tok[ed[ty][i]] = 0;

    memo = ask(tok);
    return memo;
}
 
 
// Clears the global per-call containers so that a repeated call to find_pair()
// in the same process does not reuse adjacency lists or trees from a previous call.
static void reset_highway_state(int n){
    for(int i = 0; i < n; i++){
        adj[i].clear();
        jda[i].clear();
    }
    for(int t = 0; t < 2; t++){
        ver[t].clear();
        ed[t].clear();
    }
}

// Binary search over edge prefixes: returns the smallest index e such that
// setting tok[0..e] = 1 (all other edges 0) makes ask() differ from `base`.
// Returns m_edges if no prefix changes the answer. Requires tok.size() == m_edges.
static int first_prefix_edge(int m_edges, long long base){
    int lo = -1, hi = m_edges;
    while(hi - lo > 1){
        const int mid = (lo + hi) >> 1;
        fill(tok.begin(), tok.end(), 0);
        fill(tok.begin(), tok.begin() + mid + 1, 1);
        if(ask(tok) == base)
            lo = mid;
        else
            hi = mid;
    }
    return hi;
}

// Binary search over side `ty`'s post-ordered tree edges: finds the smallest
// index i with get(i, ty) != base and returns ver[ty][i]. If no index changes
// the answer, returns `root` (the side's own endpoint of edge gooded).
static int side_endpoint(int ty, int root, long long base){
    const int cnt = int(ed[ty].size());
    int lo = -1, hi = cnt;
    while(hi - lo > 1){
        const int mid = (lo + hi) >> 1;
        if(get(mid, ty) == base)
            lo = mid;
        else
            hi = mid;
    }
    return hi == cnt ? root : ver[ty][hi];
}

// Entry point called by the grader. Assumes the IOI 2018 "Highway Tolls"
// contract: ask(w) reports the S-T travel cost when edge i costs a if
// w[i] == 0 and b if w[i] == 1, and answer(s, t) reports the hidden pair.
//
// Steps:
//  1. Query with all tolls 0 to get the baseline cost len * a.
//  2. Binary-search an edge prefix to obtain edge `gooded`.
//  3. BFS from both endpoints of `gooded`. Each vertex strictly closer to one
//     endpoint attaches to that endpoint's BFS tree (via jda); equidistant
//     vertices are dropped.
//  4. Post-order DFS each tree into ver/ed, then binary-search each side for
//     its endpoint with side_endpoint().
// `b` is not needed: every comparison is against the all-light cost.
void find_pair(int n, std::vector<int> U, std::vector<int> V, int a, int b){
    (void)b;
    reset_highway_state(n);
    memset(pas, -1, sizeof pas);

    u = U;
    v = V;
    m = u.size();
    for(int i = 0; i < m; i++){
        adj[u[i]].push_back(i);
        adj[v[i]].push_back(i);
    }

    // Step 1: baseline query, every edge light.
    tok.assign(m, 0);
    len = ask(tok) / a;
    const long long base = len * a;

    // Step 2.
    gooded = first_prefix_edge(m, base);
    const int rt[2] = {u[gooded], v[gooded]};

    // Step 3: distances from both endpoints; remember the first BFS.
    bfs(rt[0]);
    for(int i = 0; i < n; i++){
        kh[i] = h[i];
        kpar[i] = par[i];
        kparid[i] = parid[i];
    }
    bfs(rt[1]);
    for(int i = 0; i < n; i++){
        if(h[i] < kh[i] && par[i] != -1)
            jda[par[i]].push_back(parid[i]);
        if(h[i] > kh[i] && kpar[i] != -1)
            jda[kpar[i]].push_back(kparid[i]);
    }

    // Step 4: post-order both trees, then locate each endpoint
    // (side 0 is searched before side 1, as in the original query order).
    memset(mark, false, sizeof mark);
    dfs(rt[0], 0);
    dfs(rt[1], 1);

    const int s = side_endpoint(0, rt[0], base);
    const int t = side_endpoint(1, rt[1], base);
    answer(s, t);
}
 
 
 
 
 
 
 
 
 
 
 
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...