This submission was migrated from the previous version of oj.uz, which used a different machine for grading. This submission may produce a different result if resubmitted.
#include "highway.h"
#include <vector>
#include <cassert>
using namespace std;
// 64-bit type for toll totals returned by ask().
using ll = long long;
// Adjacency-list entry. `d` is the neighbouring vertex and `i` is the index
// of the connecting edge, i.e. its position in U/V and in the query vector w.
struct Edge {
int d, i;
};
// adj[v]: edges incident to v, in the order they appear in U/V.
vector<vector<Edge>> adj;
// good[v] == false once v has been pruned from the search by set_bad();
// dfs/get_c/set_b/add_candidates never step onto non-good vertices.
vector<bool> good;
// sz[v]: number of good vertices in v's subtree, as computed by the most
// recent dfs() call (its value depends on which vertex that dfs was rooted at).
vector<int> sz;
void dfs(int curr, int prev) {
sz[curr] = 1;
for (auto [c, i]: adj[curr]) {
if (c == prev || !good[c]) continue;
dfs(c, curr);
sz[curr] += sz[c];
}
}
int get_c(int curr, int prev, int tar) {
for (auto [c, i]: adj[curr]) {
if (c == prev || !good[c]) continue;
if (sz[c] >= tar) return get_c(c, curr, tar);
}
return curr;
}
// Prunes `curr` and every good vertex on its side of the tree (away from
// `prev`) by setting good[] to false.
//
// Vertices that are already pruned are not descended into, matching the other
// traversals. find_pair() only ever removes whole branches hanging off the
// current centroid, so the good region stays connected. Nothing good can lie
// beyond an already-pruned vertex, and re-walking those regions on every
// prune is wasted work.
void set_bad(int curr, int prev) {
    good[curr] = false;
    for (auto [c, i] : adj[curr]) {
        if (c == prev || !good[c]) continue;
        set_bad(c, curr);
    }
}
void set_b(int curr, int prev, vector<int> &w) {
for (auto [c, i]: adj[curr]) {
if (c == prev || !good[c]) continue;
w[i] = true;
set_b(c, curr, w);
}
}
void add_candidates(int curr, int prev, int d, vector<Edge> &tars) {
for (auto [c, i]: adj[curr]) {
if (c == prev || !good[c]) continue;
if (d == 1) tars.push_back({c, i});
else add_candidates(c, curr, d-1, tars);
}
}
// guaranteed d >= 1
int find_one(int centroid, vector<Edge> sources, int d, ll distA) {
vector<Edge> candidates;
for (auto [c, i]: sources) {
if (d == 1) candidates.push_back({c, i});
else add_candidates(c, centroid, d-1, candidates);
}
int low = 0, high = candidates.size();
while (low+1 < high) {
int mid = (low+high)/2;
vector<int> w(adj.size()-1, false);
for (int i = low; i < mid; i++) w[candidates[i].i] = true;
if (ask(w) != distA) high = mid;
else low = mid;
}
return candidates[low].d;
}
void find_pair(int N, vector<int> U, vector<int> V, int A, int B) {
const int M = U.size();
assert(M == N-1); // TODO
adj.resize(N);
good.resize(N, true);
sz.resize(N);
for (int i = 0; i < M; i++) {
adj[U[i]].push_back({V[i], i});
adj[V[i]].push_back({U[i], i});
}
// distance between S and T
ll dist = ask(vector<int>(M, 0)) / A;
// tree is always rooted at the centroid
int centroid = 0;
while (true) {
dfs(centroid, -1);
centroid = get_c(centroid, -1, (sz[centroid]+1) / 2);
// centroid becomes root
dfs(centroid, -1);
if (sz[centroid] == 2) {
for (auto [c, i]: adj[centroid]) {
if (good[c]) {
answer(centroid, c);
return;
}
}
assert(false);
}
vector<int> w(M, false);
int cut = 0, currSz = 0;
while (currSz < (sz[centroid]-1) / 2) {
auto [c, i] = adj[centroid][cut];
if (!good[c]) {
cut++;
continue;
}
w[i] = true;
set_b(c, centroid, w);
currSz += sz[c];
cut++;
}
assert(currSz > 0 && currSz < sz[centroid]);
ll res = ask(w);
if (res == A*dist) {
for (int i = 0; i < cut; i++) {
set_bad(adj[centroid][i].d, centroid);
}
} else if (res == B*dist) {
for (int i = cut; i < adj[centroid].size(); i++) {
set_bad(adj[centroid][i].d, centroid);
}
} else {
const int x = (res - A*dist) / (B - A);
const int y = dist - x;
answer(
x == 0 ? centroid : find_one(centroid, vector<Edge>(adj[centroid].begin(), adj[centroid].begin() + cut), x, dist*A),
y == 0 ? centroid : find_one(centroid, vector<Edge>(adj[centroid].begin() + cut, adj[centroid].end()), y, dist*A)
);
return;
}
}
}
Compilation message (stderr)
highway.cpp: In function 'void find_pair(int, std::vector<int>, std::vector<int>, int, int)':
highway.cpp:128:33: warning: comparison of integer expressions of different signedness: 'int' and 'std::vector<Edge>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
128 | for (int i = cut; i < adj[centroid].size(); i++) {
| ~~^~~~~~~~~~~~~~~~~~~~~~
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| Fetching results... | | | | |

| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| Fetching results... | | | | |

| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| Fetching results... | | | | |

| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| Fetching results... | | | | |

| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| Fetching results... | | | | |

| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| Fetching results... | | | | |