#include <algorithm>
#include <cassert>
#include <functional>
#include <iostream>
#include <numeric>
#include <random>
#include <utility>
#include <vector>
#include "longesttrip.h"
using namespace std;
// Query wrapper that tolerates empty vertex sets.
// Returns false without issuing a query when either side is empty; otherwise
// forwards both sets to are_connected().
bool safe_are_connected(vector<int> S_left, vector<int> S_right) {
	const bool both_nonempty = !S_left.empty() && !S_right.empty();
	// Short-circuit keeps are_connected() from ever seeing an empty set.
	return both_nonempty && are_connected(S_left, S_right);
}
// Build adjacency lists for the subgraph induced by `nodes`.
// N     - number of vertex ids; the result has N (possibly empty) lists.
// nodes - vertices to keep; an edge is kept only if both endpoints are listed.
// edges - undirected edges; each kept edge is recorded in both directions,
//         in the order the edges appear.
vector<vector<int>> create_adj(
	int N, vector<int> nodes, vector<pair<int, int>> edges
) {
	vector<bool> member(N, false);
	for(int v: nodes) member[v] = true;

	vector<vector<int>> neighbours(N);
	for(const auto &edge: edges) {
		const int a = edge.first;
		const int b = edge.second;
		if(member[a] && member[b]) {
			neighbours[a].push_back(b);
			neighbours[b].push_back(a);
		}
	}
	return neighbours;
}
// Merge leaf l1 next to leaf l2 (the caller has verified that l1 and l2 are
// adjacent) so the spanning tree described by `par`/`deg` loses one leaf.
//
// The step below repeatedly adds the edge (u, v) and cuts the edge from u to
// its old neighbour, walking away from l1 while the visited vertex has degree
// 1. The net effect is that the chain starting at l1 is re-hung below l2 and
// cut off at the first vertex that keeps degree >= 2.
//
// l1, l2 - two current leaves joined by a real graph edge.
// l3     - a leaf the caller popped but does not merge; it is pushed back.
// leaves - leaf pool; receives l3 and the new end of the re-hung chain.
// par    - parent of each vertex id (-1 for the root); updated in place.
// deg    - tree degree of each vertex id; updated in place.
void prune_tree(
	int l1, int l2, int l3, vector<int> &leaves, vector<int> &par,
	vector<int> &deg
) {
	leaves.push_back(l3);
	int u = l1, v = l2;
	while(deg[u] == 1) {
		int pu = par[u];
		if(pu == -1) {
			// u is the current root and its only remaining neighbour is one of
			// its children (v's parent was already rewritten on the previous
			// step). Cut that edge instead and promote the child to root.
			// Following par here would index deg[-1] and leave a cycle in par.
			pu = static_cast<int>(find(par.begin(), par.end(), u) - par.begin());
			assert(pu < static_cast<int>(par.size()));
			par[pu] = -1;
		}
		par[u] = v;
		deg[v]++;
		deg[pu]--;
		v = u;
		u = pu;
	}
	// v is the last vertex moved, i.e. the far end of the re-hung chain.
	leaves.push_back(v);
}
// Turn a spanning tree of `nodes` into a path through all of `nodes`, using
// are_connected() queries to merge leaves until at most two remain.
//
// N     - number of vertex ids; ids lie in [0, N).
// nodes - vertices of the tree; may be a strict subset of [0, N).
// adj   - tree adjacency lists indexed by vertex id.
// Returns the vertices of `nodes` in path order.
vector<int> solve_tree(int N, vector<int> nodes, vector<vector<int>> adj) {
	// We assume that the tree is connected here
	int root = nodes[0];
	vector<int> deg(N, 0);
	for(auto node: nodes) deg[node] = adj[node].size();
	// Prefer a branching vertex as root; without one the tree is already a path.
	for(auto node: nodes) {
		if(deg[node] >= 3) {
			root = node;
			break;
		}
	}
	// par is indexed by vertex id, not by position in `nodes`, so it needs N
	// entries even when `nodes` is a subset. -1 marks the root (and ids outside
	// `nodes`).
	vector<int> par(N, -1);
	vector<int> leaves;
	function<void(int, int)> dfs = [&](int node, int parent) {
		par[node] = parent;
		bool is_leaf = true;
		for(auto child: adj[node]) {
			if(child == parent) continue;
			dfs(child, node);
			is_leaf = false;
		}
		if(is_leaf) leaves.push_back(node);
	};
	dfs(root, -1);
	mt19937 mt(42);
	while(leaves.size() >= 3) {
		// Take three random leaves; at least one pair among them is adjacent.
		shuffle(leaves.begin(), leaves.end(), mt);
		int l1 = leaves.back();
		leaves.pop_back();
		int l2 = leaves.back();
		leaves.pop_back();
		int l3 = leaves.back();
		leaves.pop_back();
		if(are_connected({l1}, {l2})) {
			prune_tree(l1, l2, l3, leaves, par, deg);
		} else if(are_connected({l1}, {l3})) {
			prune_tree(l1, l3, l2, leaves, par, deg);
		} else {
			// Delta >= 1, means that l2 and l3 are connected
			prune_tree(l2, l3, l1, leaves, par, deg);
		}
	}
	// It's a path: walk from one end up to the root, then append the other
	// end's walk (reversed) up to the first already-visited vertex.
	vector<int> path;
	vector<bool> used(N, false);
	int u = leaves[0];
	while(u != -1) {
		path.push_back(u);
		used[u] = true;
		u = par[u];
	}
	if(leaves.size() == 2) {
		u = leaves[1];
		vector<int> rev_path;
		while(!used[u]) {
			rev_path.push_back(u);
			u = par[u];
		}
		reverse(rev_path.begin(), rev_path.end());
		path.insert(path.end(), rev_path.begin(), rev_path.end());
	}
	return path;
}
// Locate one concrete edge between two vertex sets that are known to have at
// least one edge between them.
//
// Each level shuffles both sides with `mt` and splits each into two halves.
// It then probes up to three of the four half-pairs. The fourth pair is taken
// without a query, because the precondition guarantees some crossing edge.
// Recursion ends when both sides hold a single vertex.
pair<int, int> find_one_edge(
	vector<int> S_left, vector<int> S_right, mt19937 &mt
) {
	assert(S_left.size() >= 1 && S_right.size() >= 1);
	const bool single_pair = S_left.size() == 1 && S_right.size() == 1;
	if(single_pair) return make_pair(S_left.front(), S_right.front());

	shuffle(S_left.begin(), S_left.end(), mt);
	shuffle(S_right.begin(), S_right.end(), mt);

	// halves[s][p]: s = 0 for S_left, 1 for S_right; p = 0 first half, 1 second.
	vector<int> halves[2][2];
	const vector<int> *sides[2] = {&S_left, &S_right};
	for(int s = 0; s < 2; s++) {
		const vector<int> &src = *sides[s];
		auto cut = src.begin() + src.size() / 2;
		halves[s][0].assign(src.begin(), cut);
		halves[s][1].assign(cut, src.end());
	}

	// Probe (first,first), (first,second), (second,first) in that order.
	for(int k = 0; k < 3; k++) {
		const vector<int> &a = halves[0][k / 2];
		const vector<int> &b = halves[1][k % 2];
		if(safe_are_connected(a, b)) return find_one_edge(a, b, mt);
	}
	return find_one_edge(halves[0][1], halves[1][1], mt);
}
// Entry point: returns a path of vertex ids found with are_connected() queries
// on the vertices 0..N-1.
//
// Vertices are processed in a shuffled order. The first one (the "hub") is
// tested against every other vertex. Its neighbours are joined to it in a
// star. All remaining vertices are joined in a star to the first
// non-neighbour found.
//
// This assumes, as the D >= 1 density guarantee implies, that any two
// vertices non-adjacent to the hub are adjacent to each other. If the two
// sides are connected, one crossing edge links the stars into a spanning tree
// of all vertices. Otherwise the larger side is solved on its own.
vector<int> longest_trip(int N, int D) {
	assert(D >= 1);
	mt19937 mt(42);
	vector<int> order(N);
	iota(order.begin(), order.end(), 0);
	shuffle(order.begin(), order.end(), mt);

	const int hub = order[0];
	vector<int> hub_side{hub};  // hub plus every vertex adjacent to it
	vector<int> rest_side;      // vertices not adjacent to the hub
	vector<pair<int, int>> edges;

	// Solve the spanning tree over all N vertices formed by `edges`.
	auto solve_everything = [&]() {
		vector<int> everyone(N);
		iota(everyone.begin(), everyone.end(), 0);
		return solve_tree(N, everyone, create_adj(N, everyone, edges));
	};

	int first_miss = 1;
	while(first_miss < N && are_connected({hub}, {order[first_miss]})) {
		hub_side.push_back(order[first_miss]);
		edges.emplace_back(hub, order[first_miss]);
		++first_miss;
	}
	if(first_miss == N) return solve_everything();

	// The first non-neighbour becomes the centre of the second star.
	const int rest_hub = order[first_miss];
	rest_side.push_back(rest_hub);
	for(int i = first_miss + 1; i < N; i++) {
		const int x = order[i];
		const bool adjacent_to_hub = are_connected({hub}, {x});
		vector<int> &side = adjacent_to_hub ? hub_side : rest_side;
		side.push_back(x);
		edges.emplace_back(adjacent_to_hub ? hub : rest_hub, x);
	}

	if(are_connected(hub_side, rest_side)) {
		edges.push_back(find_one_edge(hub_side, rest_side, mt));
		return solve_everything();
	}

	// Two disjoint components: return a path through the larger one
	// (ties go to the hub's side).
	vector<int> &larger = hub_side.size() < rest_side.size() ? rest_side : hub_side;
	return solve_tree(N, larger, create_adj(N, larger, edges));
}
Compilation message
longesttrip.cpp: In function 'std::pair<int, int> find_one_edge(std::vector<int>, std::vector<int>, std::mt19937&)':
longesttrip.cpp:130:2: error: 'assert' was not declared in this scope
130 | assert(S_left.size() >= 1 && S_right.size() >= 1);
| ^~~~~~
longesttrip.cpp:9:1: note: 'assert' is defined in header '<cassert>'; did you forget to '#include <cassert>'?
8 | #include "longesttrip.h"
+++ |+#include <cassert>
9 |
longesttrip.cpp: In function 'std::vector<int> longest_trip(int, int)':
longesttrip.cpp:159:2: error: 'assert' was not declared in this scope
159 | assert(D >= 1);
| ^~~~~~
longesttrip.cpp:159:2: note: 'assert' is defined in header '<cassert>'; did you forget to '#include <cassert>'?