Submission #423994

# | Time | Username | Problem | Language | Result | Execution time | Memory
423994 | (not shown) | Tangent | Highway Tolls (IOI18_highway) | C++17 | 51 / 100 | 369 ms | 262148 KiB
#include "highway.h"
#include "bits/stdc++.h"

using namespace std;

// Short names for the integer, pair and vector types used throughout the file.
using ll = long long;
using pii = std::pair<int, int>;
using pll = std::pair<ll, ll>;
using vii = std::vector<int>;
using vll = std::vector<ll>;
using vpii = std::vector<pii>;
using vpll = std::vector<pll>;
using vvii = std::vector<vii>;
using vvll = std::vector<vll>;
using vvpii = std::vector<vpii>;
using vvpll = std::vector<vpll>;

// ffor(i, a, b): loop a long long `i` over the half-open range [a, b).
#define ffor(i, a, b) for (ll i = (a); i < (b); ++i)
// rep(i, n): loop `i` over [0, n).
#define rep(i, n) ffor(i, 0, n)
// forin(x, a): range-for binding each element of `a` by reference as `x`.
#define forin(x, a) for (auto &x : a)
// all(a): expands to the begin/end iterator pair of container `a`.
#define all(a) (a).begin(), (a).end()

pii solve(int N, std::vector<int> U, std::vector<int> V, int A, int B, vii &verts, int vertcount) {
  int M = U.size();
  vvpii adj(N);
  rep(i, M) {
    if (verts[U[i]] && verts[V[i]]) {
      adj[U[i]].emplace_back(V[i], i);
      adj[V[i]].emplace_back(U[i], i);
    }
  }

  vii cnt(N, 1), depth(N), pedge(N, -1);
  vvpii children(N);

  function<void(int, int)> dfs;
  dfs = [&](int x, int p) {
    forin(y, adj[x]) {
      if (y.first != p) {
        children[x].emplace_back(y);
        depth[y.first] = depth[x] + 1;
        pedge[y.first] = y.second;
        dfs(y.first, x);
        cnt[x] += cnt[y.first];
      }
    }
  };
  rep(i, N) {
    if (verts[i]) {
      dfs(i, -1);
      break;
    }
  }
  int root = -1, bestmax = N;
  rep(i, N) {
    if (!verts[i]) continue;
    int currmax = vertcount - cnt[i];
    forin(j, children[i]) {
      currmax = max(currmax, cnt[j.first]);
    }
    if (currmax < bestmax) {
      root = i;
      bestmax = currmax;
    }
  }

  cnt.assign(N, 1);
  depth.assign(N, 0);
  pedge.assign(N, -1);
  children.clear();
  children.resize(N);
  dfs(root, -1);

  vii query(M);
  ll dist = ask(query) / A;
  forin(child, children[root]) {
    query[child.second] = 1;
  }
  ll rootans = (ask(query) - (A * dist)) / (B - A);
  forin(child, children[root]) {
    query[child.second] = 0;
  }

  if (rootans == 0) {
    vvii egroups(children[root].size());
    rep(i, children[root].size()) {
      deque<int> q = {children[root][i].first};
      while (!q.empty()) {
        forin(child, children[q.front()]) {
          q.emplace_back(child.first);
          egroups[i].emplace_back(child.second);
        }
        q.pop_front();
      }
    }

    int lo = 0, hi = egroups.size() - 1;
    while (lo < hi) {
      vii query2(M);
      int mid = (lo + hi + 1) / 2;
      ffor(i, lo, mid) {
        forin(e, egroups[i]) {
          query2[e] = 1;
        }
      }
      if (ask(query2) > dist * A) {
        hi = mid - 1;
      } else {
        lo = mid;
      }
    }
    vii nverts(N);
    int nvertcount = 0;
    deque<int> q = {children[root][lo].first};
    while (!q.empty()) {
      nverts[q.front()] = 1;
      nvertcount++;
      forin(child, children[q.front()]) {
        q.emplace_back(child.first);
      }
      q.pop_front();
    }
    return solve(N, U, V, A, B, nverts, nvertcount);
  } else if (rootans == 1) {
    vii cand;
    rep(i, N) {
      if (verts[i] && depth[i] == dist) {
        cand.emplace_back(i);
      }
    }
    int lo = 0, hi = cand.size() - 1;
    while (lo < hi) {
      vii query2(M);
      int mid = (lo + hi + 1) / 2;
      ffor(i, lo, mid) {
        query2[pedge[cand[i]]] = 1;
      }
      if (ask(query2) > dist * A) {
        hi = mid - 1;
      } else {
        lo = mid;
      }
    }
    return {root, cand[lo]};
  } else {
    int lo = 0, hi = children[root].size() - 1;
    vii child_inds, res;
    while (lo < hi) {
      vii query2(M);
      int mid = (lo + hi + 1) / 2;
      ffor(i, lo, mid) {
        query2[children[root][i].second] = 1;
      }
      ll childans = (ask(query2) - (A * dist)) / (B - A);
      if (childans == 0) {
        lo = mid;
      } else if (childans == 1) {
        int lo1 = lo, hi1 = mid - 1, lo2 = mid, hi2 = hi;
        while (lo1 < hi1) {
          vii query3(M);
          int mid1 = (lo1 + hi1 + 1) / 2;
          ffor(i, lo1, mid1) {
            query3[children[root][i].second] = 1;
          }
          if (ask(query3) > dist * A) {
            hi1 = mid1 - 1;
          } else {
            lo1 = mid1;
          }
        }
        while (lo2 < hi2) {
          vii query3(M);
          int mid2 = (lo2 + hi2 + 1) / 2;
          ffor(i, lo2, mid2) {
            query3[children[root][i].second] = 1;
          }
          if (ask(query3) > dist * A) {
            hi2 = mid2 - 1;
          } else {
            lo2 = mid2;
          }
        }
        child_inds = {lo1, lo2};
        break;
      } else {
        hi = mid - 1;
      }
    }
    forin(ind, child_inds) {
      vii query2(M);
      vii vgroup;
      deque<int> q = {children[root][ind].first};
      while (!q.empty()) {
        vgroup.emplace_back(q.front());
        forin(child, children[q.front()]) {
          q.emplace_back(child.first);
          query2[child.second] = 1;
        }
        q.pop_front();
      }
      ll cdist = (ask(query2) - (dist * A)) / (B - A);
      vii cand;
      forin(i, vgroup) {
        if (depth[i] == cdist + 1) {
          cand.emplace_back(i);
        }
      }
      int lo1 = 0, hi1 = cand.size() - 1;
      while (lo1 < hi1) {
        vii query3(M);
        int mid1 = (lo1 + hi1 + 1) / 2;
        ffor(i, lo1, mid1) {
          query3[pedge[cand[i]]] = 1;
        }
        if (ask(query3) > dist * A) {
          hi1 = mid1 - 1;
        } else {
          lo1 = mid1;
        }
      }
      res.emplace_back(cand[lo1]);
    }
    return {res[0], res[1]};
  }
}


// Task entry point (presumably invoked by the grader via highway.h).
// Every vertex starts as a candidate; the pair found by solve() is reported
// through answer().
void find_pair(int N, std::vector<int> U, std::vector<int> V, int A, int B) {
  std::vector<int> candidates(N, 1);
  const auto [s, t] = solve(N, U, V, A, B, candidates, N);
  answer(s, t);
}

Compilation message (stderr)

highway.cpp: In function 'pii solve(int, std::vector<int>, std::vector<int>, int, int, vii&, int)':
highway.cpp:18:40: warning: comparison of integer expressions of different signedness: 'll' {aka 'long long int'} and 'std::vector<std::pair<int, int> >::size_type' {aka 'long unsigned int'} [-Wsign-compare]
   18 | #define ffor(i, a, b) for (ll i = a; i < b; i++)
      |                                        ^
highway.cpp:19:19: note: in expansion of macro 'ffor'
   19 | #define rep(i, n) ffor(i, 0, n)
      |                   ^~~~
highway.cpp:86:5: note: in expansion of macro 'rep'
   86 |     rep(i, children[root].size()) {
      |     ^~~
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...