Submission #1242593

# | Time | Username | Problem | Language | Result | Execution time | Memory
1242593 | | nikd | Highway Tolls (IOI18_highway) | C++20 | 5 / 100 | 74 ms | 16736 KiB
#include "highway.h"
#include <bits/stdc++.h>
using namespace std;

void find_pair(int n, std::vector<int> u, std::vector<int> v, int a, int b) {
  int m = u.size();
  if(m!=n-1) assert(0);
  vector<vector<array<int, 2>>> adj(n);
  for(int i = 0; i<m; i++){
    adj[u[i]].push_back({v[i], i});
    adj[v[i]].push_back({u[i], i});
  }
  vector<int> h(n);
  int h_max = 0;
  vector<int> par(n, -1);
  auto dfs = [&](auto&& dfs, int v, int p)->void{
    par[v] = p;
    h_max = max(h[v], h_max);
    for(auto [u, i]: adj[v]){
      if(u==p) continue;
      h[u] = h[v]+1;
      dfs(dfs, u, v);
    }
    return;
  };
  h[0] = 0;
  dfs(dfs, 0, -1);
  vector<vector<int>> layers(h_max+1);
  for(int i = 0; i<n; i++) layers[h[i]].push_back(i);
  vector<int> w(m, 0);
  int d = ask(w)/a;
  int l = 0; int r = h_max;
  while(r-l>1){
    int mid = (l+r)/2;
    for(int i = 0; i<m; i++) w[i] = 0;
    for(int i = h_max-1; i>=mid; i--){
      for(int j: layers[i]){
        for(auto [u, idx]: adj[j]){
          if(u == par[j]) continue;
          w[idx] = 1;
        }
      }
    }
    int res = ask(w);
    if(res > d*a) l = mid;
    else r = mid;
  }
  int height_lowest = l+1;
  l = -1; r = layers[height_lowest].size()-1;
  while(r-l>1){
    int mid = (l+r+1)/2;
    for(int i = 0; i<m; i++) w[i] = 0;
    for(int i = 0; i<=mid; i++){
      int j = layers[height_lowest][i];
      for(auto [u, idx]: adj[j]){
        if(u == par[j]) w[idx] = 1;
      }
    }
    int res = ask(w);
    if(res > d*a) r = mid;
    else l = mid;
  }
  int s = layers[height_lowest][r];
  //swap(layers[height_lowest][r], layers[height_lowest][0]);
  vector<int> leaves;
  leaves.push_back(s);
  for(int i = 0; i<n; i++){
    if(i != s && i != 0 && (h[i] == height_lowest || (adj[i].size() == 1 && h[i] < height_lowest))) leaves.push_back(i);
  }
  vector<int> col(m, INT_MAX);
  vector<int> col_nd(n, INT_MAX);
  col_nd[0] =0;
  for(int i = 0; i<leaves.size(); i++){
    int nd = leaves[i];
    while(nd!=-1){
      if(col_nd[nd] != INT_MAX) break;
      col_nd[nd] = i;
      nd = par[nd];
    }
  }
  for(int v = 0; v<n; v++){
    for(auto [u, idx]: adj[v]){
      if(u==par[v])continue;
      col[idx] = col_nd[u];
    }
  }
  l = -1; r = leaves.size()-1;
  while(r-l>1){
    int mid = (l+r+1)/2;
    for(int i =0; i<m; i++){
      if(col[i] > mid) w[i] =0;
      else w[i] = 1;
    }  
    int res = ask(w);
    if(res == d*b) r = mid;
    else l = mid; 
  }
  if(r == 0){
    int curr_d = 0;
    int nd = s;
    while(curr_d < d){
      nd = par[nd];
      curr_d++;
    }
    answer(s, nd);
  }
  else{
    int lca = leaves[r];
    while(col_nd[lca] != 0) lca = par[lca]; 
    int t = leaves[r];
    int curr_d = height_lowest - h[lca]+h[leaves[r]]-h[lca];
    while(curr_d > d){
      t = par[t];
      curr_d--;
    }
    answer(s, t);
  }
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...