제출 #1172729

# | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리
1172729 |  | anmattroi | 통행료 (IOI18_highway) | C++17 | 69 / 100 | 97 ms | 17032 KiB
#include "highway.h"
#include <bits/stdc++.h>
#define fi first
#define se second
using namespace std;
using ii = pair<int, int>;


// Special case A = 1, B = 2 (dispatched from find_pair).
//
// Colour every vertex 0/1 and make an edge heavy (weight 1 -> toll B = 2)
// when both endpoints share a colour, light (weight 0 -> toll A = 1) when it
// crosses colours. Any route then costs (#light) + 2 * (#heavy), whose parity
// equals the number of colour changes along the route mod 2. Every S-T route
// therefore has the same parity, and ask() (presumably the minimum toll over
// all routes) is odd exactly when S and T get different colours.
//
// A random split is retried until it separates S and T; each attempt succeeds
// with probability 1/2. Each side is then bisected with the same parity test
// to pin down its endpoint.
//
// NOTE(review): the retry loop has no upper bound; the expected number of
// attempts is 2.
void find_pair_sub5(int N, vector<int> U, vector<int> V, int A, int B) {
    const int edge_count = U.size();
    mt19937 gen(chrono::steady_clock::now().time_since_epoch().count());

    // Edge weights for the colouring in which exactly the vertices listed in
    // `marked` have colour 1 (all others colour 0).
    auto weights_for = [&](const vector<int>& marked) {
        vector<int> colour(N, 0);
        for (int v : marked) colour[v] = 1;
        vector<int> w(edge_count, 0);
        for (int e = 0; e < edge_count; e++) w[e] = (colour[U[e]] == colour[V[e]]);
        return w;
    };

    // True when `marked` contains exactly one of S and T (odd toll).
    auto separates = [&](const vector<int>& marked) {
        return ask(weights_for(marked)) % 2 == 1;
    };

    // Halves `cand` until one vertex is left. Requires that `cand` holds
    // exactly one of S, T, with the other endpoint outside `cand`.
    auto narrow = [&](vector<int> cand) {
        while (cand.size() > 1) {
            const size_t half = cand.size() / 2;
            vector<int> front(cand.begin(), cand.begin() + half);
            if (separates(front)) cand.swap(front);
            else cand.erase(cand.begin(), cand.begin() + half);
        }
        return cand.front();
    };

    vector<int> side1, side2;
    do {
        side1.clear();
        side2.clear();
        for (int v = 0; v < N; v++) {
            if (uniform_int_distribution<int>(0, 1)(gen)) side1.push_back(v);
            else side2.push_back(v);
        }
    } while (!separates(side1));

    const int s = narrow(side1);
    const int t = narrow(side2);
    answer(s, t);
}


// Breadth-first search from `src` over `adj`, whose entries are
// {neighbour, edge index}. On return:
//  - dist[v] is the hop count from src (-1 if unreachable);
//  - parent_edge[v] is the index of the edge through which v was first
//    reached (-1 for src and for unreachable vertices).
// Iterative, so deep graphs (e.g. long paths) cannot overflow the call stack.
static void highway_bfs(int src, const vector<vector<pair<int, int> > >& adj,
                        vector<int>& dist, vector<int>& parent_edge) {
    const int n = static_cast<int>(adj.size());
    dist.assign(n, -1);
    parent_edge.assign(n, -1);
    vector<int> order;
    order.reserve(n);
    order.push_back(src);
    dist[src] = 0;
    for (size_t head = 0; head < order.size(); ++head) {
        const int u = order[head];
        for (const auto& [v, idx] : adj[u]) {
            if (dist[v] != -1) continue;
            dist[v] = dist[u] + 1;
            parent_edge[v] = idx;
            order.push_back(v);
        }
    }
}

// Finds the hidden pair {S, T} with ask()/answer(); works for any connected
// graph, not only trees.
//
// Assumes 1 <= A < B (the edge search relies on all-heavy costing more than
// all-light) and that ask() returns the minimum total toll over S-T routes.
//
//  1. base = ask(all light) = D*A, where D is the S-T hop distance.
//  2. Binary-search the smallest k such that making edges 0..k heavy raises
//     the cost. With 0..k-1 heavy, some shortest route avoids them; with 0..k
//     heavy, none does. So edge k = (X, Y) lies on some shortest S-T route.
//  3. BFS from X and Y. Vertices strictly closer to X form X's side; those
//     strictly closer to Y form Y's side. The endpoint of that route on X's
//     end lies on X's side and the other on Y's side.
//  4. Make edge k and both BFS trees light and everything else heavy. The
//     light edges form a tree whose S-T path is a shortest path, so the cost
//     stays exactly `base` until one of that path's tree edges becomes heavy.
//     Scanning each side deepest-first, the first vertex whose parent edge
//     raises the cost is that side's endpoint. The root (depth 0) sorts last
//     and is the fallback.
//
// Queries used: 1 + ceil(log2 M) + ceil(log2 |sideX|) + ceil(log2 |sideY|).
// The original extra assert-only ask() call is gone, so no query is wasted.
void find_pair(int N, vector<int> U, vector<int> V, int A, int B) {
    // The general algorithm below also handles A = 1, B = 2; the dedicated
    // parity-based routine is kept for that case.
    if (A == 1 && B == 2) {
        find_pair_sub5(N, U, V, A, B);
        return;
    }
    const int M = static_cast<int>(U.size());

    vector<vector<pair<int, int> > > adj(N);
    for (int i = 0; i < M; i++) {
        adj[U[i]].emplace_back(V[i], i);
        adj[V[i]].emplace_back(U[i], i);
    }

    const long long base = ask(vector<int>(M, 0));

    // Invariant: heavy prefix [0..lo] keeps the cost at `base`, while heavy
    // prefix [0..hi] raises it. lo = -1 is the all-light query above; hi = M-1
    // is all-heavy, costing D*B > D*A.
    int lo = -1, hi = M - 1;
    while (hi - lo > 1) {
        const int mid = lo + (hi - lo) / 2;
        vector<int> w(M, 0);
        fill(w.begin(), w.begin() + mid + 1, 1);
        if (ask(w) != base) hi = mid;
        else lo = mid;
    }
    const int X = U[hi], Y = V[hi];

    vector<int> distX, distY, parX, parY;
    highway_bfs(X, adj, distX, parX);
    highway_bfs(Y, adj, distY, parY);

    // Equidistant vertices belong to neither side; they cannot be the route's
    // endpoints, which lie strictly on one side each.
    vector<int> sideX, sideY;
    for (int v = 0; v < N; v++) {
        if (distX[v] < distY[v]) sideX.push_back(v);
        else if (distY[v] < distX[v]) sideY.push_back(v);
    }

    // Light: edge k plus the BFS-tree parent edges inside each side. Heavy:
    // every other edge. (A BFS parent of a side vertex is on the same side.)
    vector<int> tree_w(M, 1);
    tree_w[hi] = 0;
    for (int v : sideX) if (v != X) tree_w[parX[v]] = 0;
    for (int v : sideY) if (v != Y) tree_w[parY[v]] = 0;

    // Returns the endpoint lying in `side`, given that side's BFS distances
    // and parent edges. The side root has distance 0 and sorts last, so its
    // (nonexistent) parent edge is never touched.
    auto endpoint_on = [&](vector<int> side, const vector<int>& dist,
                           const vector<int>& par) {
        sort(side.begin(), side.end(),
             [&](int a, int b) { return dist[a] > dist[b]; });
        int l = -1, h = static_cast<int>(side.size()) - 1;
        while (h - l > 1) {
            const int mid = l + (h - l) / 2;
            vector<int> w = tree_w;
            for (int i = 0; i <= mid; i++) w[par[side[i]]] = 1;
            if (ask(w) != base) h = mid;
            else l = mid;
        }
        return side[h];
    };

    const int S = endpoint_on(sideX, distX, parX);
    const int T = endpoint_on(sideY, distY, parY);
    answer(S, T);
}

/*
4 3
1 3 1 3
1 2
2 0
0 3
*/
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...