이 제출은 이전 버전의 oj.uz에서 채점하였습니다. 현재는 제출 당시와는 다른 서버에서 채점을 하기 때문에, 다시 제출하면 결과가 달라질 수도 있습니다.
#include "factories.h"
#include <vector>
//#include <iostream>
#include <algorithm>
#include <set>
#include <stack>
using namespace std;
// Upper bound on the number of tree vertices; used to size every global array.
const int NMAX = 500000;
// Highest sparse-table level that BuildRmq will fill (it loops i = 1..LGMAX).
const int LGMAX = 19;
// Sentinel meaning "no distance known yet"; the dp arrays are reset to it per query.
const long long INF = 1e18;
// Adjacency list of the input tree: g[u] holds (neighbour, edge length) pairs.
vector<pair<int, int> > g[NMAX + 5];
// dist[v]: sum of edge lengths on the path from the DFS root to v (set in dfs).
long long dist[NMAX + 5];
// lvl[v]: number of edges between the DFS root and v (set in dfs).
int lvl[NMAX + 5];
// firstAp[v]: index of the first occurrence of v in the Euler tour rmq[0].
int firstAp[NMAX + 2];
// rmq[0] is the Euler tour produced by dfs; rmq[i][j] is the vertex with the
// smallest lvl among the 2^i tour entries starting at position j (see BuildRmq).
vector<int> rmq[LGMAX + 2];
// log2[len] = floor(log2(len)) for len >= 1, filled by BuildRmq.
// NOTE(review): this name collides with the C library's log2 and makes GCC emit
// -Wbuiltin-declaration-mismatch; renaming it also requires touching BuildRmq and LCA.
int log2[2 * NMAX + 2];
void dfs(int node, int parent = -1) {
    firstAp[node] = (int) rmq[0].size();
    rmq[0].push_back(node);
    for (auto it : g[node]) {
        if (it.first != parent) {
            lvl[it.first] = lvl[node] + 1;
            dist[it.first] = dist[node] + it.second;
            dfs(it.first, node);
            rmq[0].push_back(node);
        }
    }
}
// Returns whichever of A and B has the strictly smaller lvl; B is returned on ties.
int GetMinLvl(int A, int B) {
    return (lvl[A] < lvl[B]) ? A : B;
}
// Fills the log2 lookup table and then builds sparse-table levels 1..LGMAX on
// top of the Euler tour already stored in rmq[0]. A level is only built while
// its window length (2^level) still fits inside the tour.
void BuildRmq() {
    log2[1] = 0;
    for (int len = 2; len <= 2 * NMAX; ++len) {
        log2[len] = 1 + log2[len >> 1];
    }

    const int tourLen = (int) rmq[0].size();
    for (int level = 1; level <= LGMAX && (1 << level) <= tourLen; ++level) {
        const int half = 1 << (level - 1);
        // Number of start positions whose full 2^level window lies in the tour.
        const int windows = tourLen - (1 << level) + 1;
        for (int start = 0; start < windows; ++start) {
            const int leftBest = rmq[level - 1][start];
            const int rightBest = rmq[level - 1][start + half];
            rmq[level].push_back(GetMinLvl(leftBest, rightBest));
        }
    }
}
int LCA(int A, int B) {
    A = firstAp[A];
    B = firstAp[B];
    if (A > B)
        swap(A, B);
    int k = log2[B - A + 1];
    return GetMinLvl(rmq[k][A], rmq[k][B - (1 << k) + 1]);
}
// Stores the N-1 undirected edges (A[e], B[e]) of length D[e] in the adjacency
// list g, then runs the DFS from vertex 0 and builds the sparse table.
void Init(int N, int A[], int B[], int D[]) {
    for (int e = 0; e + 1 < N; ++e) {
        const int u = A[e];
        const int v = B[e];
        const int w = D[e];
        g[u].emplace_back(v, w);
        g[v].emplace_back(u, w);
    }
    dfs(0);
    BuildRmq();
}
// Adjacency list of the per-query virtual tree: g2[u] holds
// (neighbour, path length in the original tree) pairs. The length must be a
// long long: Solve stores GetDist() results here, and with an int weight the
// pair's converting constructor would silently truncate any path longer than
// INT_MAX.
vector<pair<int, long long>> g2[NMAX + 2];
// dp1[v]: as computed by computeDp1, the shortest distance from v to a red
//         vertex inside v's subtree of the virtual tree (INF if none).
// dp2[v]: as computed by computeDp2, the shortest distance from v to a red
//         vertex anywhere in the virtual tree.
long long dp1[NMAX + 2], dp2[NMAX + 2];
// Post-order pass over the virtual tree g2: after the call, dp1[node] is the
// minimum of its initial value and (edge weight + dp1[child]) over all
// children, where `parent` is the one neighbour that is skipped.
void computeDp1(int node, int parent = -1) {
    for (const auto &edge : g2[node]) {
        const int child = edge.first;
        if (child == parent) {
            continue;
        }
        computeDp1(child, node);
        const long long viaChild = edge.second + dp1[child];
        if (viaChild < dp1[node]) {
            dp1[node] = viaChild;
        }
    }
}
// Pre-order pass over the virtual tree g2: first lowers dp2[node] to dp1[node]
// if that is smaller, then offers (edge weight + dp2[node]) to every child
// before descending into it. `parent` is the one neighbour that is skipped.
void computeDp2(int node, int parent = -1) {
    if (dp1[node] < dp2[node]) {
        dp2[node] = dp1[node];
    }
    for (const auto &edge : g2[node]) {
        const int child = edge.first;
        if (child == parent) {
            continue;
        }
        const long long viaParent = edge.second + dp2[node];
        if (viaParent < dp2[child]) {
            dp2[child] = viaParent;
        }
        computeDp2(child, node);
    }
}
// True when A is the lowest common ancestor of A and B, i.e. A is B itself or
// one of B's ancestors.
bool isAncestor(int A, int B) {
    const int meet = LCA(A, B);
    return meet == A;
}
// Length of the tree path between A and B, derived from root distances:
// both root paths share the part down to the LCA, which is removed twice.
long long GetDist(int A, int B) {
    const long long rootToMeet = dist[LCA(A, B)];
    return dist[A] + dist[B] - 2 * rootToMeet;
}
// Answers one query on the virtual tree spanned by `base`.
//
// base - vertices of the virtual tree in the order the caller supplies them;
//        the stack scan below links each vertex to the nearest earlier vertex
//        that is its ancestor, and base[0] is used as the root.
// red  - vertices whose dp1 starts at 0 (the sources).
// blue - vertices whose dp2 values are minimised to form the answer.
//
// Returns the smallest dp2 over `blue`, or INF if none is smaller.
long long Solve(vector<int> &base, vector<int> &red, vector<int> &blue) {
    // Reset only the per-vertex state this query will touch.
    for (size_t i = 0; i < base.size(); ++i) {
        const int v = base[i];
        g2[v].clear();
        dp1[v] = INF;
        dp2[v] = INF;
    }
    for (size_t i = 0; i < red.size(); ++i) {
        dp1[red[i]] = 0;
    }

    // `chain` holds the current root-to-vertex path of the virtual tree.
    vector<int> chain;
    for (size_t i = 0; i < base.size(); ++i) {
        const int vert = base[i];
        while (!chain.empty() && !isAncestor(chain.back(), vert)) {
            chain.pop_back();
        }
        if (!chain.empty()) {
            const int up = chain.back();
            const long long len = GetDist(up, vert);
            g2[up].emplace_back(vert, len);
            g2[vert].emplace_back(up, len);
        }
        chain.push_back(vert);
    }

    computeDp1(base[0]);
    computeDp2(base[0]);

    long long best = INF;
    for (size_t i = 0; i < blue.size(); ++i) {
        best = min(best, dp2[blue[i]]);
    }
    return best;
}
// Orders vertices by the position of their first occurrence in the Euler tour.
inline bool cmp(const int A, const int B) {
    const int posA = firstAp[A];
    const int posB = firstAp[B];
    return posA < posB;
}
long long Query(int S, int X[], int T, int Y[]) {
    vector<int> base;
    for (int i = 0; i < S; i++)
        base.push_back(X[i]);
    for (int i = 0; i < T; i++)
        base.push_back(Y[i]);
    sort(base.begin(), base.end(), cmp);
    vector<int> aux = base;
    for (int i = 1; i < (int) base.size(); i++)
        aux.push_back(LCA(base[i - 1], base[i]));
    set<int> vertices;
    for (auto it : aux)
        vertices.insert(it);
    base.clear();
    for (auto it : vertices)
        base.push_back(it);
    sort(base.begin(), base.end(), cmp);
    vector<int> red, blue;
    for (int i = 0; i < S; i++)
        red.push_back(X[i]);
    for (int i = 0; i < T; i++)
        blue.push_back(Y[i]);
    return Solve(base, red, blue);
}
컴파일 시 표준 에러 (stderr) 메시지
factories.cpp:23:5: warning: built-in function 'log2' declared as non-function [-Wbuiltin-declaration-mismatch]
   23 | int log2[2 * NMAX + 2];
      |     ^~~~
| # | Verdict | Execution time | Memory | Grader output | 
|---|
| Fetching results... | 
| # | Verdict | Execution time | Memory | Grader output | 
|---|
| Fetching results... | 
| # | Verdict | Execution time | Memory | Grader output | 
|---|
| Fetching results... |