Submission #594021

#TimeUsernameProblemLanguageResultExecution timeMemory
594021Soumya1Highway Tolls (IOI18_highway)C++17
18 / 100
171 ms31012 KiB
#include "highway.h"
#include <bits/stdc++.h>
#ifdef __LOCAL__
  #include <debug_local.h>
#endif
using namespace std;
const int mxN = 90005;
vector<pair<int, int>> ad[mxN];
vector<int> l[mxN];
int dep[mxN], pid[mxN], par[mxN];
void dfs(int u, int p, int d = 0) {
  dep[u] = d;
  l[d].push_back(u);
  for (auto [v, id] : ad[u]) {
    if (v == p) continue;
    par[v] = u;
    pid[v] = id;
    dfs(v, u, d + 1);
  }
}
void find_pair(int n, vector<int> U, vector<int> V, int A, int B) {
  long long a = A, b = B;
  for (int i = 0; i < n - 1; i++) {
    ad[U[i]].push_back({V[i], i});
    ad[V[i]].push_back({U[i], i});
  }
  pid[0] = -1;
  dfs(0, -1);
  long long len = ask(vector<int> (n - 1)) / a;
  int lo = 0, hi = n - 1;
  while (lo < hi) {
    int mid = (lo + hi) >> 1;
    vector<int> v(n - 1);
    for (int i = 0; i <= mid; i++) {
      for (int j : l[i]) {
        if (pid[j] != -1) v[pid[j]] = 1;
      }
    }
    if (ask(v) == 1LL * len * b) hi = mid;
    else lo = mid + 1;
  }
  int d1 = lo;
  lo = 0, hi = n - 1;
  while (lo < hi) {
    int mid = (lo + hi + 1) >> 1;
    vector<int> v(n - 1);
    for (int i = 0; i <= mid; i++) {
      for (int j : l[i]) {
        if (pid[j] != -1) v[pid[j]] = 1;
      }
    }
    if (ask(v) == 1LL * len * a) lo = mid;
    else hi = mid - 1;
  }
  int lcad = lo;
  int d2 = len - d1 + 2 * lo;
  lo = 0, hi = l[d1].size() - 1;
  while (lo < hi) {
    int mid = (lo + hi) >> 1;
    vector<int> v(n - 1);
    for (int i = 0; i <= mid; i++) {
      int j = l[d1][i];
      if (pid[j] != -1) v[pid[j]] = 1;
    }
    if (ask(v) >= 1LL * (len - 1) * a + b) hi = mid;
    else lo = mid + 1;
  }
  int s = l[d1][lo];
  if (lcad == d2) {
    int t = s;
    while (len--) t = par[t];
    answer(s, t);
    return;
  }
  int up = d1 - d2;
  int k = s;
  while (up--) k = par[k];
  l[d2].erase(find(l[d2].begin(), l[d2].end(), k));
  lo = 0, hi = l[d2].size() - 1;
  while (lo < hi) {
    int mid = (lo + hi) >> 1;
    vector<int> v(n - 1);
    for (int i = 0; i <= mid; i++) {
      int j = l[d2][i];
      if (pid[j] != -1) v[pid[j]] = 1;
    }
    if (ask(v) == 1LL * (len - 1) * a + b) hi = mid;
    else lo = mid + 1;
  }
  int t = l[d2][lo];
  answer(s, t);
}
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...