#include "scales.h"
#include<bits/stdc++.h>
using namespace std;
// Hypothetical ordering under test: v[coin] is that coin's weight rank
// (0 = lightest). query() reads it; cnt() overwrites it per candidate.
vector<int> v{0, 1, 2, 3, 4, 5};
// Queries recorded so far, each stored as {type, a, b, c, d, reply}
// with 0-based coin indices.
vector<vector<int>> qs{};
// Candidate orderings, packed base-10 with v[0] in the units digit
// (the encoding decoded by cnt()).
unordered_set<int> us{};
// Simulates one grader query against the hypothetical ordering held in the
// global `v`, where v[coin] is that coin's weight rank (larger = heavier).
// Coin arguments are 0-based indices into `v`. Ranks are assumed distinct:
// callers in this file only ever store permutations of 0..5 in `v`.
//   t == 0: heaviest of a, b, c            (mirrors getHeaviest)
//   t == 1: lightest of a, b, c            (mirrors getLightest)
//   t == 2: median of a, b, c              (mirrors getMedian)
//   t == 3: lightest of a, b, c that is heavier than d; if none is heavier
//           than d, the lightest of a, b, c (mirrors getNextLightest)
// Returns the selected coin index, or -1 for an unknown query type or a
// type-3 query without a reference coin d. Every path returns a value.
int query(int t, int a, int b, int c, int d=-1){
    switch (t){
    case 0:  // heaviest
        if (v[a] > v[b] && v[a] > v[c]) return a;
        if (v[b] > v[a] && v[b] > v[c]) return b;
        return c;
    case 1:  // lightest
        if (v[a] < v[b] && v[a] < v[c]) return a;
        if (v[b] < v[a] && v[b] < v[c]) return b;
        return c;
    case 2:  // median: strictly between the other two
        if (v[a] > min(v[b], v[c]) && v[a] < max(v[b], v[c])) return a;
        if (v[b] > min(v[a], v[c]) && v[b] < max(v[a], v[c])) return b;
        return c;
    case 3:
        // The default d == -1 would index v[-1]; reject it instead.
        if (d < 0) return -1;
        // A coin is the answer when it is heavier than d and neither of the
        // other two coins lies between d and it in weight.
        if (v[a] > v[d] && (v[b] > v[a] || v[b] < v[d]) && (v[c] > v[a] || v[c] < v[d])) return a;
        if (v[b] > v[d] && (v[a] > v[b] || v[a] < v[d]) && (v[c] > v[b] || v[c] < v[d])) return b;
        if (v[c] > v[d] && (v[a] > v[c] || v[a] < v[d]) && (v[b] > v[c] || v[b] < v[d])) return c;
        return query(1, a, b, c);  // nothing heavier than d: lightest of the three
    default:
        return -1;
    }
}
int cnt(bool b = false){
    int res = 0;
    for (int x : us){
        v = {x%10, (x/10)%10, (x/100)%10, (x/1000)%10, (x/10000)%10, x/100000};
        bool valid = true;
        for (vector<int> q : qs){
            if (query(q[0], q[1], q[2], q[3], q[4]) != q[5]){
                valid = false;
                if (b) us.erase(x);
                break;
            }
        }
        if (valid) res++;
    }
    return res;
}
// Grader entry point invoked before the test cases. T is presumably the
// number of test cases (per the task's grader; confirm). Nothing is
// precomputed: orderCoins() resets `us` and `qs` on every call.
void init(int T){
    static_cast<void>(T);  // intentionally unused
}
// Worst case for a candidate query {type, a, b, c, d} (0-based coins):
// the largest number of orderings in `us` still consistent with `qs` after
// any one of the three possible replies a, b or c. `qs` is restored before
// returning, and `us` is not modified (cnt() is called without erasing).
static int worstCaseRemaining(int type, int a, int b, int c, int d){
    int worst = 0;
    for (int reply : {a, b, c}){
        qs.push_back({type, a, b, c, d, reply});
        worst = max(worst, cnt());
        qs.pop_back();
    }
    return worst;
}

// Sends query q = {type, a, b, c, d} to the grader and returns the reported
// coin. Coins are 0-based here and 1-based for the grader; d is only used by
// type 3.
static int askGrader(const vector<int>& q){
    int reply = 0;
    switch (q[0]){
    case 0: reply = getHeaviest(q[1]+1, q[2]+1, q[3]+1); break;
    case 1: reply = getLightest(q[1]+1, q[2]+1, q[3]+1); break;
    case 2: reply = getMedian(q[1]+1, q[2]+1, q[3]+1); break;
    case 3: reply = getNextLightest(q[1]+1, q[2]+1, q[3]+1, q[4]+1); break;
    default: assert(false && "unknown query type");
    }
    assert(reply >= 1 && reply <= 6);
    return reply - 1;
}

// Determines the coin order for one test case and reports it via answer().
// Greedy minimax: it repeatedly asks the query whose worst-case reply leaves
// the fewest candidate orderings, until at most one remains. It then reports
// the ordering consistent with every recorded reply.
void orderCoins(){
    // Every permutation of ranks starts as a candidate, packed base-10 with
    // the rank of coin 0 in the units digit (the encoding cnt() decodes).
    us.clear();
    vector<int> w = {0, 1, 2, 3, 4, 5};
    do {
        us.insert(w[0]+w[1]*10+w[2]*100+w[3]*1000+w[4]*10000+w[5]*100000);
    } while (next_permutation(w.begin(), w.end()));
    qs.clear();

    while (true){
        int bsf = 1000;  // larger than the 720 possible candidates
        vector<int> bq;
        for (int i=0; i<4; i++){
            for (int j=i+1; j<5; j++){
                for (int k=j+1; k<6; k++){
                    // Types 0-2 (heaviest / lightest / median) take no d.
                    for (int type=0; type<3; type++){
                        const int cur = worstCaseRemaining(type, i, j, k, -1);
                        if (cur < bsf){
                            bsf = cur;
                            bq = {type, i, j, k, -1};
                        }
                    }
                    // Type 3 (next lightest) with a reference coin outside {i, j, k}.
                    for (int d=0; d<6; d++){
                        if (d == i || d == j || d == k) continue;
                        const int cur = worstCaseRemaining(3, i, j, k, d);
                        if (cur < bsf){
                            bsf = cur;
                            bq = {3, i, j, k, d};
                        }
                    }
                }
            }
        }
        assert(bq.size() == 5);
        bq.push_back(askGrader(bq));
        qs.push_back(bq);
        // Prune `us` to consistent orderings. Stopping at <= 1 (not == 1)
        // means inconsistent replies that empty the set cannot loop forever.
        if (cnt(true) <= 1) break;
    }

    // Find the ordering consistent with every recorded reply and report it.
    v = {0, 1, 2, 3, 4, 5};
    do {
        const bool consistent = all_of(qs.begin(), qs.end(), [](const vector<int>& q){
            return query(q[0], q[1], q[2], q[3], q[4]) == q[5];
        });
        if (consistent){
            int arr[6];
            // v[coin] is the coin's rank, so arr lists coins lightest first.
            for (int coin=0; coin<6; coin++) arr[v[coin]] = coin+1;
            answer(arr);
            return;
        }
    } while (next_permutation(v.begin(), v.end()));
}
/*
Compilation message (stderr):
scales.cpp: In function 'int query(int, int, int, int, int)':
scales.cpp:31:1: warning: control reaches end of non-void function [-Wreturn-type]
   31 | }
      | ^

| # | Verdict | Execution time | Memory | Grader output |
|---|---------|----------------|--------|---------------|
| Fetching results... |
*/