제출 #765873

#제출 시각아이디문제언어결과실행 시간메모리
765873t6twotwo통행료 (IOI18_highway)C++17
51 / 100
143 ms16612 KiB
#include "highway.h"
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
void find_pair(int N, vector<int> U, vector<int> V, int A, int B) {
    int M = U.size(), lo, hi;
    ll s = ask(vector<int>(M));
    vector<vector<pair<int, int>>> adj(N);
    for (int i = 0; i < M; i++) {
        adj[U[i]].emplace_back(V[i], i);
        adj[V[i]].emplace_back(U[i], i);
    }
    if (M == N - 1) {
        vector<int> dep(N), t(N);
        auto dfs = [&](auto dfs, int x, int p) -> void {
            for (auto [y, z] : adj[x]) {
                if (y == p) {
                    continue;
                }
                t[y] = z;
                dep[y] = dep[x] + 1;
                dfs(dfs, y, x);
            }
        };
        dfs(dfs, 0, -1);
        vector<vector<int>> v(N);
        for (int i = 0; i < N; i++) {
            v[dep[i]].push_back(i);
        }
        lo = 1, hi = N - 1;
        while (lo < hi) {
            int mi = (lo + hi) / 2;
            vector<int> w(M, 1);
            for (int i = 1; i <= mi; i++) {
                for (int x : v[i]) {
                    w[t[x]] = 0;
                }
            }
            if (ask(w) == s) {
                hi = mi;
            } else {
                lo = mi + 1;
            }
        }
        int d = lo; lo = 0, hi = v[d].size() - 1;
        while (lo < hi) {
            int mi = (lo + hi) / 2;
            vector<int> w(M);
            for (int i = 0; i <= mi; i++) {
                w[t[v[d][i]]] = 1;
            }
            if (ask(w) != s) {
                hi = mi;
            } else {
                lo = mi + 1;
            }
        }
        int S = v[d][lo];
        dep[S] = 0;
        dfs(dfs, S, -1);
        vector<int> q;
        for (int i = 0; i < N; i++) {
            if (dep[i] == s / A) {
                q.push_back(i);
            }
        }
        lo = 0, hi = q.size() - 1;
        while (lo < hi) {
            int mi = (lo + hi) / 2;
            vector<int> w(M);
            for (int i = 0; i <= mi; i++) {
                w[t[q[i]]] = 1;
            }
            if (ask(w) != s) {
                hi = mi;
            } else {
                lo = mi + 1;
            }
        }
        answer(S, q[lo]);
        return;
    }
    vector<int> W(M);
    lo = 0, hi = M - 1;
    while (lo < hi) {
        int mi = (lo + hi) / 2;
        auto w = W;
        for (int i = 0; i <= mi; i++) {
            w[i] = 1;
        }
        if (ask(w) == s) {
            W = w;
            lo = mi + 1;
        } else {
            hi = mi;
        }
    }
    auto bfs = [&](int s) {
        vector<int> dis(N, -1), f(N, -1);
        dis[s] = 0;
        queue<int> q;
        q.push(s);
        while (!q.empty()) {
            int x = q.front();
            q.pop();
            for (auto [y, z] : adj[x]) {
                if (dis[y] == -1) {
                    dis[y] = dis[x] + 1;
                    f[y] = z;
                    q.push(y);
                }
            }
        }
        return make_pair(dis, f);
    };
    int P = U[lo], Q = V[lo];
    auto [d0, f0] = bfs(P);
    auto [d1, f1] = bfs(Q);
    vector<int> w(M, 1);
    for (int i = 0; i < N; i++) {
        if (d0[i] < d1[i]) {
            int x = i;
            while (x != P) {
                w[f0[x]] = 0;
                x ^= U[f0[x]] ^ V[f0[x]];
            }
        }
    }
    int t = (s + (s / A) * (B - A) - ask(w)) / (B - A);
    vector<int> cand;
    for (int i = 0; i < N; i++) {
        if (d0[i] == t && d0[i] < d1[i]) {
            cand.push_back(i);
        }
    }
    lo = 0, hi = cand.size() - 1;
    while (lo < hi) {
        int mi = (lo + hi) / 2;
        vector<int> w(M, 1);
        for (int i = 0; i <= mi; i++) {
            int x = cand[i];
            while (x != P) {
                w[f0[x]] = 0;
                x ^= U[f0[x]] ^ V[f0[x]];
            }
        }
        if (ask(w) == s + 1ll * (s / A - t) * (B - A)) {
            hi = mi;
        } else {
            lo = mi + 1;
        }
    }
    int S = cand[lo];
    cand.clear();
    for (int i = 0; i < N; i++) {
        if (d1[i] == s / A - t - 1 && d1[i] < d0[i]) {
            cand.push_back(i);
        }
    }
    lo = 0, hi = cand.size() - 1;
    while (lo < hi) {
        int mi = (lo + hi) / 2;
        vector<int> w(M, 1);
        for (int i = 0; i <= mi; i++) {
            int x = cand[i];
            while (x != Q) {
                w[f1[x]] = 0;
                x ^= U[f1[x]] ^ V[f1[x]];
            }
        }
        if (ask(w) == s + 1ll * (t + 1) * (B - A)) {
            hi = mi;
        } else {
            lo = mi + 1;
        }
    }
    answer(S, cand[lo]);
}
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...