제출 #713869

# | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리
713869 | PixelCat | 통행료 (IOI18_highway) | C++14
12 / 100
411 ms | 262144 KiB
#ifdef NYAOWO
#include "grader.cpp"
#endif
#include "highway.h"

#include <bits/stdc++.h>
// Competitive-programming shorthands used throughout this file.
// For(i, a, b) iterates i over the closed range [a, b] (both ends inclusive).
#define For(i, a, b) for(int i = a; i <= b; i++)
#define F first
#define S second
#define all(x) x.begin(), x.end()
#define sz(x) ((int)x.size())
#define eb emplace_back
using namespace std;
using LL = long long;
using pii = pair<int, int>;

// Capacity of the per-vertex global arrays below; presumably an upper bound
// on N with some slack — TODO confirm against the task limits.
const int MAXN = 90010;

// adj[v]: neighbours of vertex v (filled from the U/V edge lists in find_pair).
vector<int> adj[MAXN];
// d[v]: depth of v below the root most recently passed to dfs().
int d[MAXN];
// eid[v]: index of the edge joining v to its parent under the current rooting
// (find_pair stores each edge index on the edge's deeper endpoint). In a tree
// the root is never the deeper endpoint, so its entry is never written.
int eid[MAXN];

void dfs(int n, int p, int dep) {
    d[n] = dep;
    for(auto &i:adj[n]) if(i != p) {
        dfs(i, n, dep + 1);
    }
}

// Asks the grader for the toll with these weights:
// - the parent edge eid[x] of every vertex x in `v` is heavy (weight 1);
// - every other edge of the M edges is light (weight 0).
LL query(const vector<int> &v, int M) {
    vector<int> weights(M);  // value-initialised: all edges start light
    for (size_t k = 0; k < v.size(); ++k) {
        const int vertex = v[k];
        weights[eid[vertex]] = 1;
    }
    return ask(weights);
}

// lowest endpoint > dep ?
bool check(int dep, int N, int M, LL dist) {
    vector<int> w(M, 0);
    For(i, 0, N) if(d[i] >= dep) {
        w[eid[i]] = 1;
    }
    return ask(w) > dist;
}

// Picks out the candidate whose parent edge lies on the hidden path, assuming
// exactly one such candidate exists in `cand`.
// - It repeatedly halves the candidate range. If making the parent edges of
//   the left half heavy raises the toll above `dist`, the target is in the
//   left half; otherwise it is in the right half.
// - It uses about log2(|cand|) queries, with the same split points (first
//   floor(m/2) elements vs. the rest) as the previous recursive version.
//
// Iterates over an index range instead of recursing on copied halves.
// Returns -1 for an empty candidate list; the recursive version recursed
// forever on it. `N` is unused and kept only for interface compatibility.
int solve(int N, int M, const vector<int> &cand, LL dist) {
    (void)N;
    if (cand.empty()) return -1;
    int lo = 0, hi = (int)cand.size();  // target lies in cand[lo, hi)
    while (hi - lo > 1) {
        const int mid = lo + (hi - lo) / 2;
        const vector<int> left(cand.begin() + lo, cand.begin() + mid);
        if (query(left, M) > dist) hi = mid;
        else lo = mid;
    }
    return cand[lo];
}

void find_pair(int N, std::vector<int> U, std::vector<int> V, int A, int B) {
    cerr << A << " " << B << "\n";
    int M = U.size();
    // assert(M == N - 1);
    LL dist = ask(vector<int>(M, 0));
    assert(dist % A == 0);

    For(i, 0, M - 1) {
        adj[U[i]].eb(V[i]);
        adj[V[i]].eb(U[i]);
    }

    dfs(0, 0, 0);
    For(i, 0, M - 1) {
        if(d[U[i]] < d[V[i]]) eid[V[i]] = i;
        else eid[U[i]] = i;
    }
    int hi = N + 1, lo = 0;
    while(hi - lo > 1) {
        int mi = (hi + lo) / 2;
        if(check(mi, N, M, dist)) lo = mi;
        else hi = mi;
    }
    vector<int> cand;
    For(i, 0, N - 1) if(d[i] == lo) cand.eb(i);
    int rt = solve(N, M, cand, dist);
    answer(0, rt);

    // dfs(rt, rt, 0);
    // For(i, 0, M - 1) {
    //     if(d[U[i]] < d[V[i]]) eid[V[i]] = i;
    //     else eid[U[i]] = i;
    // }

    // cand.clear();
    // For(i, 0, N - 1) if(d[i] == dist / A) cand.eb(i);
    // answer(rt, solve(N, M, cand, dist));

    // for (int j = 0; j < 50; ++j) {
    //     std::vector<int> w(M);
    //     for (int i = 0; i < M; ++i) {
    //         w[i] = 0;
    //     }
    //     long long toll = ask(w);
    // }
    // answer(0, N - 1);
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...