답안 #1014474

# 제출 시각 아이디 문제 언어 결과 실행 시간 메모리
1014474 2024-07-05T02:30:51 Z shiomusubi496 Prize (CEOI22_prize) C++17
0 / 100
1838 ms 390008 KB
#include <bits/stdc++.h>

// Counting-loop helpers. Each declares an `int` loop variable named by its
// first argument and casts the bounds to int:
//   rep(i, n)     : i = 0, 1, ..., n-1
//   rep2(i, a, b) : i = a, a+1, ..., b-1
//   rrep(i, n)    : i = n-1, n-2, ..., 0
#define rep(i, n) for (int i = 0; i < (int)(n); ++i)
#define rep2(i, a, b) for (int i = (int)(a); i < (int)(b); ++i)
#define rrep(i, n) for (int i = (int)(n) - 1; i >= 0; --i)
// Expands to the two arguments `begin(v), end(v)` for range-taking algorithms.
#define all(v) begin(v), end(v)

using namespace std;

using ll = long long;

// Lowers `a` to `b` when `b` is strictly smaller; returns whether `a` changed.
template<class T, class U> bool chmin(T& a, const U& b) {
    if (a > b) {
        a = b;
        return true;
    }
    return false;
}
// Raises `a` to `b` when `b` is strictly larger; returns whether `a` changed.
template<class T, class U> bool chmax(T& a, const U& b) {
    if (a < b) {
        a = b;
        return true;
    }
    return false;
}

// Binary-lifting lowest-common-ancestor structure over a rooted tree.
//
// The tree is given as adjacency lists `G_` (either child lists or undirected
// neighbour lists; the edge back to the parent is skipped) and a root `r`.
// Construction is O(N log N) time and memory; each lca() query is O(log N).
//
// The adjacency list is only read inside the constructor, so the structure
// stays valid after the caller's graph is destroyed.
class LowestCommonAncestor {
    int levels = 1;                     // number of rows in `par`; 2^levels >= N
    std::vector<int> dep;               // dep[v]: number of edges from the root to v
    std::vector<std::vector<int>> par;  // par[i][v]: 2^i-th ancestor of v, or -1 if none

public:
    // Builds the ancestor table for the tree `G_` rooted at `r`.
    // Vertices not reachable from `r` keep depth 0 and no ancestors.
    LowestCommonAncestor(const std::vector<std::vector<int>>& G_, int r) {
        const int N = static_cast<int>(G_.size());
        // The deepest vertex has depth at most N-1 < 2^levels, so `levels`
        // bits are enough to express every jump a query can need.
        while ((1 << levels) < N) ++levels;
        dep.assign(N, 0);
        par.assign(levels, std::vector<int>(N, -1));
        if (r < 0 || r >= N) return;

        // Explicit stack instead of recursion: a path-shaped tree would
        // otherwise need one call frame per vertex.
        std::vector<int> pending{r};
        while (!pending.empty()) {
            const int v = pending.back();
            pending.pop_back();
            for (int u : G_[v]) {
                if (u == par[0][v]) continue;  // edge back to the parent
                par[0][u] = v;
                dep[u] = dep[v] + 1;
                pending.push_back(u);
            }
        }

        for (int i = 0; i + 1 < levels; ++i) {
            for (int v = 0; v < N; ++v) {
                const int mid = par[i][v];
                par[i + 1][v] = (mid == -1) ? -1 : par[i][mid];
            }
        }
    }

    // Returns the lowest common ancestor of `a` and `b`.
    int lca(int a, int b) const {
        if (dep[a] > dep[b]) std::swap(a, b);
        // Lift the deeper vertex to the depth of the other one.
        const int diff = dep[b] - dep[a];
        for (int i = levels - 1; i >= 0; --i) {
            if ((diff >> i) & 1) b = par[i][b];
        }
        if (a == b) return a;
        // Lift both while their ancestors differ; they end just below the LCA.
        for (int i = levels - 1; i >= 0; --i) {
            if (par[i][a] != par[i][b]) {
                a = par[i][a];
                b = par[i][b];
            }
        }
        return par[0][a];
    }

    // Returns the number of edges between the root and `v`.
    int depth(int v) const { return dep[v]; }
};

// Disjoint-set forest that also stores, for every element, an integer offset
// ("weight") relative to the representative of its set.
//
// merge(u, v, w) records the constraint weight(v) - weight(u) == w. Only
// differences of weight() between elements of the same set are meaningful;
// the representative itself always has weight 0.
class WeightedUnionFind {
    std::vector<int> par;        // parent index, or -(set size) for a representative
    std::vector<long long> wei;  // offset of the element relative to par[] (0 for a representative)

public:
    // Creates `n` singleton sets, each element having weight 0.
    WeightedUnionFind(int n) : par(n, -1), wei(n, 0) {}

    // Returns the representative of the set containing `v` and compresses the
    // path, rewriting each visited element's offset to be relative to the
    // representative. Iterative so that a long parent chain cannot exhaust
    // the call stack.
    int find(int v) {
        int root = v;
        while (par[root] >= 0) root = par[root];

        // Sum of offsets on the path from v up to (but excluding) the root.
        long long to_root = 0;
        for (int x = v; x != root; x = par[x]) to_root += wei[x];

        // Second walk: point every element at the root with its full offset.
        for (int x = v; x != root;) {
            const int up = par[x];
            const long long own = wei[x];
            wei[x] = to_root;
            par[x] = root;
            to_root -= own;
            x = up;
        }
        return root;
    }

    // Returns the offset of `v` relative to the representative of its set.
    long long weight(int v) {
        find(v);
        return wei[v];
    }

    // Joins the sets of `u` and `v` under the constraint
    // weight(v) - weight(u) == w. The representative of v's set is attached
    // below the representative of u's set. If both are already in the same
    // set, the constraint is asserted to be consistent.
    void merge(int u, int v, long long w) {
        // Translate the constraint into one between the two representatives.
        w += weight(u);
        w -= weight(v);
        u = find(u);
        v = find(v);
        if (u == v) {
            assert(w == 0);
            return;
        }
        par[u] += par[v];  // accumulate the set size in the surviving root
        par[v] = u;
        wei[v] = w;
    }
};

int main() {
    int N, K, Q, T; scanf("%d%d%d%d", &N, &K, &Q, &T);
    vector<int> P1(N), P2(N);
    int r1 = -1, r2 = -1;
    vector<vector<int>> G1(N), G2(N);
    rep (i, N) {
        scanf("%d", &P1[i]);
        if (P1[i] == -1) r1 = i;
        else G1[--P1[i]].push_back(i);
    }
    rep (i, N) {
        scanf("%d", &P2[i]);
        if (P2[i] == -1) r2 = i;
        else G2[--P2[i]].push_back(i);
    }
    vector<int> verts;
    {
        auto dfs = [&](auto&& self, int v) -> void {
            if ((int)verts.size() == K) return;
            verts.push_back(v);
            for (int u : G1[v]) self(self, u);
        };
        dfs(dfs, r1);
    }
    int raux;
    vector<int> aux;
    vector<int> Paux(N, -1);
    vector<vector<int>> Gaux(N);
    {
        // G2 における verts の auxiliary tree を作る
        vector<int> ord(N);
        {
            int cnt = 0;
            auto dfs = [&](auto&& self, int v) -> void {
                ord[v] = cnt++;
                for (int u : G2[v]) self(self, u);
            };
            dfs(dfs, r2);
        }
        LowestCommonAncestor lca(G2, r2);
        sort(all(verts), [&](int a, int b) { return ord[a] < ord[b]; });
        for (auto i : verts) {
            aux.push_back(i);
        }
        rep (i, K - 1) {
            aux.push_back(lca.lca(verts[i], verts[i + 1]));
        }
        sort(all(aux), [&](int a, int b) { return ord[a] < ord[b]; });
        aux.erase(unique(all(aux)), aux.end());
        raux = aux.front();
        stack<int> st;
        for (int v : aux) {
            if (st.empty()) {
                st.push(v);
                continue;
            }
            while (lca.lca(st.top(), v) != st.top()) st.pop();
            Gaux[st.top()].push_back(v);
            Paux[v] = st.top();
            st.push(v);
        }
    }
    vector<int> is_aux(N, -1);
    for (auto v : aux) is_aux[v] = 0;
    for (auto v : verts) is_aux[v] = 1;

    sort(all(verts));
    rep (i, K) {
        printf("%d", verts[i] + 1);
        if (i < K - 1) printf(" ");
        else printf("\n");
    }
    fflush(stdout);
    vector<pair<int, int>> qs;
    {
        auto dfs = [&](auto&& self, int v) -> int {
            int res = -1;
            if (is_aux[v] == 1) res = v;
            for (int u : Gaux[v]) {
                int r = self(self, u);
                if (res == -1) res = r;
                else {
                    qs.emplace_back(res, r);
                }
            }
            return res;
        };
        dfs(dfs, raux);
    }
    assert((int)qs.size() == K - 1);
    for (auto [a, b] : qs) {
        printf("? %d %d\n", a + 1, b + 1);
        fflush(stdout);
    }
    printf("!\n");
    fflush(stdout);
    LowestCommonAncestor lca1(G1, r1), lca2(G2, r2);
    WeightedUnionFind uf1(N), uf2(N);
    rep (i, qs.size()) {
        auto [a, b] = qs[i];
        ll d1a, d1b, d2a, d2b;
        scanf("%lld%lld%lld%lld", &d1a, &d1b, &d2a, &d2b);
        int l1 = lca1.lca(a, b), l2 = lca2.lca(a, b);
        uf1.merge(l1, a, d1a); uf1.merge(l1, b, d1b);
        uf2.merge(l2, a, d2a); uf2.merge(l2, b, d2b);
    }
    rep (_, T) {
        int a, b; scanf("%d%d", &a, &b);
        --a, --b;
        int l1 = lca1.lca(a, b), l2 = lca2.lca(a, b);
        int d1 = uf1.weight(a) + uf1.weight(b) - 2 * uf1.weight(l1);
        int d2 = uf2.weight(a) + uf2.weight(b) - 2 * uf2.weight(l2);
        printf("%d %d\n", d1, d2);
        fflush(stdout);
    }
}

Compilation message

Main.cpp: In function 'int main()':
Main.cpp:88:26: warning: ignoring return value of 'int scanf(const char*, ...)' declared with attribute 'warn_unused_result' [-Wunused-result]
   88 |     int N, K, Q, T; scanf("%d%d%d%d", &N, &K, &Q, &T);
      |                     ~~~~~^~~~~~~~~~~~~~~~~~~~~~~~~~~~
Main.cpp:93:14: warning: ignoring return value of 'int scanf(const char*, ...)' declared with attribute 'warn_unused_result' [-Wunused-result]
   93 |         scanf("%d", &P1[i]);
      |         ~~~~~^~~~~~~~~~~~~~
Main.cpp:98:14: warning: ignoring return value of 'int scanf(const char*, ...)' declared with attribute 'warn_unused_result' [-Wunused-result]
   98 |         scanf("%d", &P2[i]);
      |         ~~~~~^~~~~~~~~~~~~~
Main.cpp:188:14: warning: ignoring return value of 'int scanf(const char*, ...)' declared with attribute 'warn_unused_result' [-Wunused-result]
  188 |         scanf("%lld%lld%lld%lld", &d1a, &d1b, &d2a, &d2b);
      |         ~~~~~^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Main.cpp:194:24: warning: ignoring return value of 'int scanf(const char*, ...)' declared with attribute 'warn_unused_result' [-Wunused-result]
  194 |         int a, b; scanf("%d%d", &a, &b);
      |                   ~~~~~^~~~~~~~~~~~~~~~
# 결과 실행 시간 메모리 Grader output
1 Execution timed out 740 ms 196292 KB Time limit exceeded (wall clock)
2 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Execution timed out 919 ms 198124 KB Time limit exceeded (wall clock)
2 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Execution timed out 643 ms 192224 KB Time limit exceeded (wall clock)
2 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Execution timed out 1637 ms 383928 KB Time limit exceeded (wall clock)
2 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Execution timed out 1838 ms 390008 KB Time limit exceeded (wall clock)
2 Halted 0 ms 0 KB -