Submission #775729

# | Time | Username | Problem | Language | Result | Execution time | Memory
775729 | | GusterGoose27 | Highway Tolls (IOI18_highway) | C++17 | 5 / 100 | 145 ms | 29020 KiB
#include "highway.h"

#include <bits/stdc++.h>

// Capacity of the per-vertex arrays below; vertex ids must be < MAXN.
constexpr int MAXN = 90000;

using namespace std;
using pii = pair<int, int>;

// Per-edge weights handed to ask(): 0 = light toll, 1 = heavy toll.
vector<int> weight;
// Vertex and edge counts of the current instance.
int n, m;
// Adjacency lists: edges[v] holds (neighbour vertex, edge index) pairs.
vector<pii> edges[MAXN];
int base;
bool vis[MAXN];
int sz[MAXN];
int a, b;

void make_sz(int cur, int p = -1) {
    sz[cur] = 1;
    for (pii e: edges[cur]) {
        int nxt = e.first;
        if (nxt == p || vis[nxt]) continue;
        make_sz(nxt, cur);
        sz[cur] += sz[nxt];
    }
}

vector<int> res;

void deactivate(int v) {
    for (pii e: edges[v]) weight[e.second] = 1;
}

void activate(int v) {
    for (pii e: edges[v]) weight[e.second] = 0;
}

int find(int cur, int dest, int p = -1) {
    for (pii e: edges[cur]) {
        if (e.first == p) continue;
        if (e.first == dest) return cur;
        int m = find(e.first, dest, cur);
        if (m != -1) return m;
    }
    return -1;
}

void set_sub(int cur, bool v, int p = -1) {
    if (v) deactivate(cur);
    else activate(cur);
    for (pii e: edges[cur]) {
        if (vis[e.first] || e.first == p) continue;
        set_sub(e.first, v, cur);
    }
}

void solve(int cur, int num = 2, int oth = -1) {
    if (num == 0) return;
    make_sz(cur);
    int tot = sz[cur];
    if (tot == 1) {
        assert(num == 1);
        res.push_back(cur);
        vis[cur] = 1;
        return;
    }
    int p = -1;
    bool f = 0;
    while (!f) {
        f = 1;
        for (pii e: edges[cur]) {
            int nxt = e.first;
            if (vis[nxt] || nxt == p) continue;
            if (2*sz[nxt] >= tot) {
                f = 0;
                p = cur;
                cur = nxt;
                break;
            }
        }
    }
    vis[cur] = 1;
    deactivate(cur);
    int val = ask(weight);
    activate(cur);
    if (val == base+b-a) {
        res.push_back(cur);
        val = base;
        num--;
        oth = cur;
    }
    if (val == base) { // figure out which subtree
        vector<pii> adj;
        for (pii e: edges[cur]) {
            if (vis[e.first]) continue;
            adj.push_back(e);
        }
        int l = 0;
        int r = adj.size();
        assert(adj.size() >= 1);
        while (r > l+1) {
            int mid = (l+r)/2;
            for (int i = l; i < mid; i++) set_sub(adj[i].first, 1);
            if (ask(weight) > base) {
                r = mid;
            }
            else {
                l = mid;
            }
            for (int i = l; i < mid; i++) set_sub(adj[i].first, 0);
        }
        solve(adj[l].first, num, oth);
    }
    else {
        assert(val == base+2*(b-a));
        while (num > 0) {
            int ig_dir = -1;
            if (num == 1) ig_dir = find(oth, cur);
            vector<pii> adj;
            for (pii e: edges[cur]) {
                if (vis[e.first] || e.first == ig_dir) continue;
                adj.push_back(e);
            }
            int l = 0;
            int r = adj.size();
            assert(adj.size() >= num);
            while (r > l+1) {
                int mid = (l+r)/2;
                for (int i = l; i < mid; i++) weight[adj[i].second] = 1;
                if (ask(weight) > base) {
                    r = mid;
                }
                else {
                    l = mid;
                }
                for (int i = l; i < mid; i++) weight[adj[i].second] = 0;
            }
            solve(adj[l].first, 1, cur);
            num--;
            oth = res[0];
        }
    }
}

// Grader entry point. Builds the adjacency lists from (U[i], V[i]), records
// the all-zero-weight cost in `base`, locates both endpoints with solve(0)
// and reports them with answer(). This solution only handles trees.
void find_pair(int N, vector<int> U, vector<int> V, int A, int B) {
    a = A; b = B;
    // The centroid search relies on unique paths, i.e. M == N - 1.
    assert(static_cast<int>(U.size()) == N - 1);
    n = N;
    m = static_cast<int>(U.size());
    for (int i = 0; i < m; i++) {
        edges[U[i]].push_back(pii(V[i], i));
        edges[V[i]].push_back(pii(U[i], i));
    }
    weight.assign(m, 0);
    base = ask(weight);
    solve(0);
    // solve must have found both endpoints before we index them.
    assert(res.size() >= 2);
    answer(res[0], res[1]);
}

Compilation message (stderr)

In file included from /usr/include/c++/10/cassert:44,
                 from /usr/include/x86_64-linux-gnu/c++/10/bits/stdc++.h:33,
                 from highway.cpp:3:
highway.cpp: In function 'void solve(int, int, int)':
highway.cpp:126:31: warning: comparison of integer expressions of different signedness: 'std::vector<std::pair<int, int> >::size_type' {aka 'long unsigned int'} and 'int' [-Wsign-compare]
  126 |             assert(adj.size() >= num);
      |                    ~~~~~~~~~~~^~~~~~
highway.cpp: In function 'void find_pair(int, std::vector<int>, std::vector<int>, int, int)':
highway.cpp:147:21: warning: comparison of integer expressions of different signedness: 'std::vector<int>::size_type' {aka 'long unsigned int'} and 'int' [-Wsign-compare]
  147 |     assert(U.size() == N-1);
      |            ~~~~~~~~~^~~~~~
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...