// Submission #1317223
//
// # | Time | Username | Problem | Language | Result | Execution time | Memory
// 1317223 | (not captured) | starplatinum | Highway Tolls (IOI18_highway) | C++17 | 100 / 100 | 134 ms | 10424 KiB
#include "highway.h"
#include <bits/stdc++.h>
using namespace std;

namespace {
    // Shared state, populated by find_pair() before any helper below is used.
    int N, M;                   // vertex count, edge count
    vector<int> U, V;           // edge i connects U[i] and V[i]
    vector<vector<int>> adj;    // adjacency lists derived from U / V
    vector<int> w;              // scratch toll assignment handed to ask()
    long long D0;               // ask() result when every edge is light (toll 0)

    // Breadth-first search from `src` over `adj`.
    // Returns the hop count from `src` to every vertex, or -1 where unreachable.
    vector<int> bfs(int src) {
        vector<int> dist(N, -1);
        vector<int> order;      // doubles as the FIFO: entries at [head, end) are pending
        order.reserve(N);
        order.push_back(src);
        dist[src] = 0;
        for (std::size_t head = 0; head < order.size(); ++head) {
            const int cur = order[head];
            for (const int nxt : adj[cur]) {
                if (dist[nxt] != -1) continue;   // already discovered
                dist[nxt] = dist[cur] + 1;
                order.push_back(nxt);
            }
        }
        return dist;
    }

    // Lights edges 0..k (toll 0) and darkens every later edge (toll 1), then
    // reports whether ask() still returns the all-light cost D0.
    bool has_edge_prefix(int k) {
        for (int i = 0; i < M; ++i) w[i] = (i > k) ? 1 : 0;
        return D0 == ask(w);
    }

    // Vertex partition used to locate one hidden endpoint per side.
    struct SideSearchContext {
        // side[x]: 0 if x is closer to u, 1 if closer to v; ties go to 0.
        vector<int> side;
        // ord[s]: the vertices with side == s, in search order.
        vector<int> ord[2];
        // pos0[x] / pos1[x]: index of x within ord[0] / ord[1]; only meaningful
        // for vertices that actually belong to that side.
        vector<int> pos0, pos1;
    };

    // Marks an edge light (0) only when both endpoints are "kept", otherwise
    // heavy (1), and reports whether ask() still returns D0.
    // Kept vertices: every vertex on the side opposite `targetSide`, plus the
    // vertices of `targetSide` whose rank (pos0 / pos1) is below `k`.
    // targetSide is expected to be 0 or 1.
    bool has_side_prefix(const SideSearchContext& ctx, int targetSide, int k) {
        const int otherSide = (targetSide == 0) ? 1 : 0;
        const vector<int>& rank = (targetSide == 0) ? ctx.pos0 : ctx.pos1;
        const auto kept = [&](int x) {
            return ctx.side[x] == otherSide || rank[x] < k;
        };
        for (int i = 0; i < M; ++i) {
            const bool bothKept = kept(U[i]) && kept(V[i]);
            w[i] = bothKept ? 0 : 1;
        }
        return D0 == ask(w);
    }

    // Binary-searches the shortest prefix of ctx.ord[targetSide] for which
    // has_side_prefix() holds and returns the last vertex of that prefix.
    // Relies on: prefix length 0 -> false, full length -> true.
    int find_endpoint_in_side(const SideSearchContext& ctx, int targetSide) {
        const vector<int>& order = ctx.ord[targetSide];
        int bad = 0;                                   // prefix length known to fail
        int good = static_cast<int>(order.size());     // prefix length known to pass
        while (good - bad > 1) {
            const int probe = bad + (good - bad) / 2;
            if (!has_side_prefix(ctx, targetSide, probe)) {
                bad = probe;
            } else {
                good = probe;
            }
        }
        return order[good - 1];
    }
}

void find_pair(int n, vector<int> u, vector<int> v, int A, int B) {
    (void)A; (void)B;

    N = n;
    U = std::move(u);
    V = std::move(v);
    M = (int)U.size();

    adj.assign(N, {});
    adj.reserve(N);
    for (int i = 0; i < M; ++i) {
        adj[U[i]].push_back(V[i]);
        adj[V[i]].push_back(U[i]);
    }

    w.assign(M, 0);
    D0 = ask(w); // all light

    // Step 1: find an edge on some shortest S-T path (monotone edge-prefix search).
    int lo = -1, hi = M - 1; // hi is true because hi means all edges light => ask == D0
    while (hi - lo > 1) {
        int mid = lo + (hi - lo) / 2;
        if (has_edge_prefix(mid)) hi = mid;
        else lo = mid;
    }
    int e = hi;
    int u0 = U[e], v0 = V[e];

    // BFS from both endpoints of that edge
    vector<int> distU = bfs(u0);
    vector<int> distV = bfs(v0);

    // Build partition by which endpoint is closer (ties arbitrary)
    SideSearchContext ctx;
    ctx.side.assign(N, 0);
    ctx.ord[0].clear();
    ctx.ord[1].clear();
    ctx.ord[0].reserve(N);
    ctx.ord[1].reserve(N);

    for (int x = 0; x < N; ++x) {
        if (distU[x] < distV[x]) ctx.side[x] = 0;
        else if (distV[x] < distU[x]) ctx.side[x] = 1;
        else ctx.side[x] = 0; // tie -> side 0
        ctx.ord[ctx.side[x]].push_back(x);
    }

    // Sort each side by distance to its anchor, then id
    sort(ctx.ord[0].begin(), ctx.ord[0].end(), [&](int a, int b) {
        if (distU[a] != distU[b]) return distU[a] < distU[b];
        return a < b;
    });
    sort(ctx.ord[1].begin(), ctx.ord[1].end(), [&](int a, int b) {
        if (distV[a] != distV[b]) return distV[a] < distV[b];
        return a < b;
    });

    ctx.pos0.assign(N, INT_MAX);
    ctx.pos1.assign(N, INT_MAX);
    for (int i = 0; i < (int)ctx.ord[0].size(); ++i) ctx.pos0[ctx.ord[0][i]] = i;
    for (int i = 0; i < (int)ctx.ord[1].size(); ++i) ctx.pos1[ctx.ord[1][i]] = i;

    // Step 2+3: find one endpoint in each side (total queries <= log|side0| + log|side1| <= 32)
    int a = find_endpoint_in_side(ctx, 0);
    int b = find_endpoint_in_side(ctx, 1);

    answer(a, b);
}
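// The result tables below were still loading ("Fetching results...") when this page was captured.
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...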