#include "sphinx.h"
#include <bits/stdc++.h>
using namespace std;
// Memo of experiment outcomes used by count_target_color_in_range, keyed by
// {left, right, target_color, parity}, so repeated range queries cost no
// extra calls to perform_experiment.
std::map<std::array<int, 4>, int> experiment_cache;
/**
 * count_target_color_in_range
 *
 * Returns how many vertices i with left <= i <= right and
 * (i % 2) == (parity % 2) have the hidden colour target_color.
 *
 * At most one experiment is performed. Every vertex is recoloured to
 * target_color except the selected ones, which are released (-1) to their
 * hidden colour. On a path the selected vertices are pairwise non-adjacent:
 * - A selected vertex whose colour differs from target_color becomes its own
 *   component and cuts the surrounding target_color run in two, adding 2 to
 *   the component count.
 * - A matching vertex merges into the run.
 * Hence components = 1 + 2 * (number of non-matching selected vertices).
 *
 * NOTE(review): that decoding assumes the path's edges are (i, i + 1) and that
 * every selected vertex has two neighbours, i.e. 1 <= left and
 * right <= n - 2. A non-matching endpoint would add only 1 component.
 *
 * Positive-selection results are memoised in experiment_cache under the key
 * {left, right, target_color, parity}. Empty selections return 0 without
 * being stored.
 */
int count_target_color_in_range(int n, int left, int right, int target_color, int parity) {
    // A lone vertex of the other parity can never be selected.
    if (left == right && (left % 2) != (parity % 2)) return 0;

    const array<int, 4> key = {left, right, target_color, parity};
    const auto cached = experiment_cache.find(key);
    if (cached != experiment_cache.end()) return cached->second;

    vector<int> colouring(n, target_color);
    int selected = 0;
    for (int v = left; v <= right; ++v) {
        if ((v % 2) != (parity % 2)) continue;
        colouring[v] = -1;  // keep this vertex's hidden colour
        ++selected;
    }
    if (selected == 0) return 0;

    const int components = perform_experiment(colouring);
    const int mismatched = (components - 1) / 2;
    const int matched = selected - mismatched;
    experiment_cache[key] = matched;
    return matched;
}
vector<int> find_colours(int n, vector<int> x, vector<int> y) {
vector<int> final_colors(n);
// 1. Find the color of the first node (index 0)
for (int c = 0; c < n; c++) {
vector<int> experiment_vec(n, c);
experiment_vec[0] = -1;
if (perform_experiment(experiment_vec) == 1) {
final_colors[0] = c;
break;
}
}
// 2. Find the color of the last node (index n-1)
for (int c = 0; c < n; c++) {
vector<int> experiment_vec(n, c);
experiment_vec[n - 1] = -1;
if (perform_experiment(experiment_vec) == 1) {
final_colors[n - 1] = c;
break;
}
}
// 3. Process even and odd indices separately for the middle nodes
for (int parity = 0; parity < 2; parity++) {
for (int color = 0; color < n; color++) {
int total_to_find = count_target_color_in_range(n, 1, n - 2, color, parity);
int found_so_far = 0;
int last_found_pos = 0;
// Use binary search to locate each instance of this color
while (found_so_far < total_to_find) {
int low = last_found_pos + 1;
int high = n - 2;
int result_pos = -1;
while (low <= high) {
int mid = low + (high - low) / 2;
if (count_target_color_in_range(n, last_found_pos + 1, mid, color, parity) > 0) {
result_pos = mid;
high = mid - 1;
} else {
low = mid + 1;
}
}
if (result_pos != -1) {
final_colors[result_pos] = color;
found_so_far++;
last_found_pos = result_pos;
} else {
break;
}
}
}
}
return final_colors;
}