#include <bits/stdc++.h>
using namespace std;
// Sentinel for "no candidate found yet"; dfs_vt and calc start their minima from it.
const long long inf = 1e18;
// n: node count handed to Init. timer: counter dfs_precalc increments to stamp in_time.
int n, timer;
// adj[u]: (neighbour, edge weight) pairs; Init inserts every edge in both directions.
vector<vector<pair<int, long long>>> adj;
// up[u][0] is u's parent in the rooted tree and up[u][i] = up[up[u][i-1]][i-1];
// -1 means that ancestor does not exist.
vector<vector<int>> up;
// dist[v]: sum of edge weights from the DFS root down to v.
vector<long long> dist;
// in_time[u]: value of timer when u was entered; out_time[u]: value of timer once
// all of u's descendants were visited. Together they drive is_ancestor.
vector<int> in_time, out_time;
// vt_adj[u]: neighbours of u in the tree assembled by build() (stored in both directions).
vector<vector<int>> vt_adj;
// node_types[u]: -1 = untouched by the current build, -2 = touched but without a type,
//                0 / 1 = member of the first / second list passed to build().
// top_up: nodes that build() found with no ancestor left on its stack; calc() starts from them.
// nodes: build()'s working list of nodes, ordered by in_time.
// active_nodes: every node build() touched, so the next build() can reset only those.
vector<int> node_types, top_up, nodes, active_nodes;
// Roots the tree at `u` (whose parent is `par`, -1 for the root) and fills, for
// every node reachable from `u` without going back through `par`:
//   in_time / out_time  - DFS entry stamp and the last stamp issued in the subtree,
//   up                  - binary-lifting ancestor table (up[x][0] is the parent),
//   dist                - weighted distance from `u`'s own dist value downwards.
//
// The traversal is iterative on purpose: a path-shaped tree would otherwise need
// one call frame per node and can exhaust the call stack. Children are still
// visited in adjacency-list order, so the stamps match a plain recursive DFS.
void dfs_precalc(int u, int par) {
    // Each entry is (node, index of the next adjacency entry to look at).
    vector<pair<int, size_t>> pending;

    // Work done at the moment a node is first reached.
    const auto enter = [&](int node, int parent) {
        in_time[node] = ++timer;
        up[node][0] = parent;
        for (int i = 1; i < 20; i++) {
            // Entries stay at their initial value once the chain of ancestors ends.
            if (up[node][i - 1] != -1)
                up[node][i] = up[up[node][i - 1]][i - 1];
        }
        pending.emplace_back(node, size_t{0});
    };

    enter(u, par);
    while (!pending.empty()) {
        const int node = pending.back().first;
        const size_t next = pending.back().second;

        if (next == adj[node].size()) {
            // Whole subtree handled: timer now holds the largest stamp inside it.
            out_time[node] = timer;
            pending.pop_back();
            continue;
        }

        // Advance the cursor first; enter() may grow `pending` and move its storage.
        ++pending.back().second;

        const auto &[child, weight] = adj[node][next];
        if (child == up[node][0]) continue;  // do not walk back to the parent

        dist[child] = dist[node] + weight;
        enter(child, node);
    }
}
// True when v lies in the subtree of u, judged by the DFS stamps: v must have been
// entered no earlier than u and finished no later. A node is its own ancestor.
bool is_ancestor(int u, int v) {
    if (in_time[v] < in_time[u]) {
        return false;
    }
    return out_time[u] >= out_time[v];
}
// Lowest common ancestor of u and v, using the `up` jump table and is_ancestor.
int lca(int u, int v) {
    if (is_ancestor(u, v)) return u;
    if (is_ancestor(v, u)) return v;

    // Lift u as far as possible while staying strictly below any ancestor of v,
    // trying the longest jumps first.
    int level = 20;
    while (level-- > 0) {
        const int jump = up[u][level];
        if (jump == -1 || is_ancestor(jump, v)) {
            continue;
        }
        u = jump;
    }

    // u is now the child of the answer on the path towards the original u.
    return up[u][0];
}
// Rebuilds the auxiliary tree over the nodes in type_0 and type_1 plus the LCAs of
// neighbouring nodes in DFS order. Results are left in the globals:
//   node_types - 0 / 1 for listed nodes, -2 for nodes added only as an LCA,
//   vt_adj     - parent/child links of the auxiliary tree, stored both ways,
//   top_up     - nodes that ended up without a parent in the auxiliary tree,
//   active_nodes - everything touched, so the next call can undo it cheaply.
void build(const vector<int> &type_0, const vector<int> &type_1) {
    // Undo only what the previous call touched.
    for (const int stale : active_nodes) {
        vt_adj[stale].clear();
        node_types[stale] = -1;
    }
    active_nodes.clear();
    top_up.clear();
    nodes.clear();

    const auto by_entry = [&](int a, int b) { return in_time[a] < in_time[b]; };

    // Remember a node the first time this call sees it.
    const auto touch = [&](int x) {
        if (node_types[x] != -1) return;
        node_types[x] = -2;
        active_nodes.push_back(x);
    };

    const auto add_listed = [&](int x, int type) {
        nodes.push_back(x);
        touch(x);
        node_types[x] = type;
    };

    for (const int x : type_0) add_listed(x, 0);
    for (const int x : type_1) add_listed(x, 1);
    sort(nodes.begin(), nodes.end(), by_entry);

    // Only the listed nodes are paired up; the LCAs appended here are not revisited.
    const size_t listed = nodes.size();
    for (size_t i = 1; i < listed; ++i) {
        const int meet = lca(nodes[i - 1], nodes[i]);
        touch(meet);
        nodes.push_back(meet);
    }

    sort(nodes.begin(), nodes.end(), by_entry);
    nodes.erase(unique(nodes.begin(), nodes.end()), nodes.end());

    // Sweep in DFS order keeping the current root-to-node chain; the nearest
    // remaining ancestor on the chain becomes the parent in the auxiliary tree.
    vector<int> chain;
    for (const int cur : nodes) {
        while (!chain.empty() && !is_ancestor(chain.back(), cur)) {
            chain.pop_back();
        }
        if (chain.empty()) {
            top_up.push_back(cur);
        } else {
            const int above = chain.back();
            vt_adj[above].push_back(cur);
            vt_adj[cur].push_back(above);
        }
        chain.push_back(cur);
    }
}
// Walks the auxiliary tree (vt_adj) below `u`, never stepping back to `par`.
//
// For every visited node it computes the distance to the closest type-0 node and
// to the closest type-1 node inside that node's subtree (inf when there is none),
// and lowers `res` to the sum of the two whenever that is smaller.
//
// Returns the (type-0, type-1) distance pair for `u` itself.
//
// Iterative on purpose: the auxiliary tree can be a long chain, and one call
// frame per level can exhaust the call stack. Nodes are finished in post-order,
// so `res` is updated in the same sequence a recursive walk would use.
pair<long long, long long> dfs_vt(int u, int par, long long &res) {
    struct Frame {
        int node;
        int parent;
        size_t next;                         // next index into vt_adj[node]
        pair<long long, long long> nearest;  // best distances seen so far
    };

    const auto make_frame = [&](int node, int parent) {
        Frame f{node, parent, 0, {inf, inf}};
        if (node_types[node] == 0) f.nearest.first = 0;
        if (node_types[node] == 1) f.nearest.second = 0;
        return f;
    };

    vector<Frame> stk;
    stk.push_back(make_frame(u, par));

    while (true) {
        Frame &top = stk.back();

        if (top.next < vt_adj[top.node].size()) {
            const int child = vt_adj[top.node][top.next++];
            if (child != top.parent) {
                const int node = top.node;  // copy: push_back may move `top`
                stk.push_back(make_frame(child, node));
            }
            continue;
        }

        // All children folded in: this node's pair is final.
        const Frame done = top;
        stk.pop_back();
        res = min(res, done.nearest.first + done.nearest.second);
        if (stk.empty()) {
            return done.nearest;
        }

        // Hand the result to the parent, lengthened by the connecting path.
        Frame &above = stk.back();
        const long long w = dist[done.node] - dist[above.node];
        above.nearest.first = min(above.nearest.first, done.nearest.first + w);
        above.nearest.second = min(above.nearest.second, done.nearest.second + w);
    }
}
// Runs dfs_vt from every entry of top_up and returns the smallest value any of
// those walks produced; inf if top_up is empty or no walk improved on inf.
long long calc() {
    long long best = inf;
    for (const int root : top_up) {
        long long found = inf;
        dfs_vt(root, -1, found);
        if (found < best) {
            best = found;
        }
    }
    return best;
}
// Resets all global state for a tree with N nodes and N - 1 weighted edges, where
// edge i joins A[i] and B[i] with weight D[i], then precomputes the DFS tables by
// rooting the tree at node 1. Every table gets N + 1 slots.
void Init(int N, int A[], int B[], int D[]) {
    n = N;
    timer = 0;

    const int slots = n + 1;
    adj.assign(slots, vector<pair<int, long long>>());
    vt_adj.assign(slots, vector<int>());
    up.assign(slots, vector<int>(20, -1));
    dist.assign(slots, 0);
    in_time.assign(slots, 0);
    out_time.assign(slots, 0);
    node_types.assign(slots, -1);
    active_nodes.clear();

    for (int e = 0; e + 1 < N; ++e) {
        const int a = A[e];
        const int b = B[e];
        const long long weight = D[e];
        adj[a].emplace_back(b, weight);
        adj[b].emplace_back(a, weight);
    }

    dfs_precalc(1, -1);
}
// Answers one query: the first S entries of X form one group of nodes, the first
// T entries of Y the other. Builds the auxiliary tree over both groups and
// returns whatever calc() reports for it.
long long Query(int S, int X[], int T, int Y[]) {
    vector<int> type_0(S);
    vector<int> type_1(T);
    copy_n(X, S, type_0.begin());
    copy_n(Y, T, type_1.begin());

    build(type_0, type_1);
    return calc();
}
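// Submission status table captured with the source (not part of the program);
// the results had not finished loading when it was captured:
// # | Verdict | Execution time | Memory | Grader output |
// ---|
// Fetching results... |
// # | Verdict | Execution time | Memory | Grader output |
// ---|
// Fetching results... |
// # | Verdict | Execution time | Memory | Grader output |
// ---|
// Fetching results... |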