Submission #290588

# | Time | Username | Problem | Language | Result | Execution time | Memory
290588 | (not shown) | ChrisT | Highway Tolls (IOI18_highway) | C++17 | 51 / 100 | 673 ms | 29416 KiB
#include <bits/stdc++.h> #include "highway.h" using namespace std; const int MN = 9e4 + 5; vector<pair<int,int>> adj[MN]; vector<pair<int,int>> atDep[MN]; pair<int,int> par[MN]; vector<int> w; int depth[MN], mxDep; void dfs (int cur, int p = -1) { if (~p) atDep[depth[cur]].push_back({cur,par[cur].second}); mxDep = max(mxDep,depth[cur]); for (auto [i,j] : adj[cur]) if (i != p) { par[i] = {cur,j}; depth[i] = depth[cur] + 1; dfs(i,cur); } } void find_pair (int n, vector<int> u, vector<int> v, int a, int b) { assert((int)u.size() == n - 1); for (int i = 0; i + 1 < n; i++) { adj[++u[i]].emplace_back(++v[i],i); adj[v[i]].emplace_back(u[i],i); } w.resize(n-1); long long smallDist = ask(w); dfs(1); auto find = [&] (set<int> banned) { int low = 1, high = mxDep, mid, ans = -1; while (low <= high) { mid = (low + high) / 2; for (int i = mid; i <= mxDep; i++) for (auto p : atDep[i]) if (!banned.count(p.second)) w[p.second] = 1; if (ask(w) != smallDist) low = (ans = mid) + 1; else high = mid - 1; for (int i = mid; i <= mxDep; i++) for (auto p : atDep[i]) if (!banned.count(p.second))w[p.second] = 0; } if (!~ans) return -1; function<int(int,int)> get = [&] (int l, int r) { if (l == r) return atDep[ans][l].first; int mid = (l + r) / 2; for (int i = l; i <= mid; i++) if (!banned.count(atDep[ans][i].second)) w[atDep[ans][i].second] = 1; long long got = ask(w); for (int i = l; i <= mid; i++) if (!banned.count(atDep[ans][i].second)) w[atDep[ans][i].second] = 0; if (got != smallDist) return get(l,mid); return get(mid+1,r); }; return get(0,(int)atDep[ans].size() - 1); }; int s = find({}); set<int> ban; int cur = s; while (cur != 1) { ban.insert(par[cur].second); cur = par[cur].first; } int t = find(ban); if (!~t) { //answer is on root-->s path vector<pair<int,int>> go; cur = s; while (cur != 1) { go.push_back(par[cur]); cur = par[cur].first; } int low = 0, high = (int)go.size() - 1,mid,ans=-1; while (low <= high) { mid = (low + high) / 2; w[go[mid].second] = 1; if (ask(w) != smallDist) 
low = (ans = mid) + 1; else high = mid - 1; w[go[mid].second] = 0; } assert(~ans); t = go[ans].first; } answer(--s,--t); }
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...