#include <bits/stdc++.h>
using namespace std;
// Capacity of the per-node arrays (node ids must be < MAXN).
const int MAXN = 500000 + 5;
// Euler-tour times filled by dfs(): tin on entry, tout = last tin assigned inside the
// subtree, so v lies in subtree(u) iff tin[u] <= tin[v] && tout[v] <= tout[u].
int tin[MAXN], tout[MAXN], timer;
// dep: number of edges from the root.
int dep[MAXN];
// dtr: weighted distance from the root. 64-bit because a long path can accumulate
// many edge weights and exceed INT_MAX (signed overflow is undefined behavior).
long long dtr[MAXN];
// Adjacency list: (neighbor, edge weight).
vector<pair<int,int>> adj[MAXN];
// Highest binary-lifting level; up[v][k] is the 2^k-th ancestor of v, or -1 above the root.
int lg;
vector<vector<int>> up;
void dfs(int u, int p) {
tin[u] = ++timer;
up[u][0] = p;
for (int i = 1; i <= lg; ++i) {
int mid = up[u][i-1];
if (mid < 0) break;
up[u][i] = up[mid][i-1];
}
for (auto &e : adj[u]) {
int v = e.first, w = e.second;
if (v == p) continue;
dep[v] = dep[u] + 1;
dtr[v] = dtr[u] + w;
dfs(v, u);
}
tout[u] = timer;
}
int lca(int u, int v) {
if (dep[u] < dep[v]) swap(u, v);
int diff = dep[u] - dep[v];
for (int i = 0; diff; ++i) {
if (diff & 1) u = up[u][i];
diff >>= 1;
}
if (u == v) return u;
for (int i = lg; i >= 0; --i) {
if (up[u][i] != up[v][i]) {
u = up[u][i];
v = up[v][i];
}
}
return up[u][0];
}
/**
 * Minimum weighted tree distance between any node of a[0..n) and any node of b[0..m).
 *
 * Builds a virtual tree over the query nodes plus the LCAs of tin-adjacent pairs, then
 * runs one bottom-up DP over it. Requires Init() to have been called (uses tin, tout,
 * dtr and lca()).
 *
 * @param a  first node set, n entries
 * @param n  size of a
 * @param b  second node set, m entries
 * @param m  size of b
 * @return   the minimum distance (0 if the sets share a node), or LLONG_MAX / 4 if either
 *           set is empty and no pair exists.
 */
long long solve(int a[], int n, int b[], int m) {
    const long long INF = LLONG_MAX / 4;
    if (n <= 0 || m <= 0) return INF;  // no (a, b) pair; also avoids indexing an empty tree

    auto byTin = [&](int x, int y) { return tin[x] < tin[y]; };

    // 1) Query nodes plus LCAs of neighbours in tin order = the virtual tree's node set.
    vector<int> nodes;
    nodes.reserve(2 * (static_cast<size_t>(n) + static_cast<size_t>(m)));
    nodes.insert(nodes.end(), a, a + n);
    nodes.insert(nodes.end(), b, b + m);
    sort(nodes.begin(), nodes.end(), byTin);
    const int orig = static_cast<int>(nodes.size());
    for (int i = 1; i < orig; ++i) {
        nodes.push_back(lca(nodes[i - 1], nodes[i]));
    }
    sort(nodes.begin(), nodes.end(), byTin);
    nodes.erase(unique(nodes.begin(), nodes.end()), nodes.end());
    const int K = static_cast<int>(nodes.size());

    // tin values are distinct, so binary search on tin locates a node's index.
    auto indexOf = [&](int v) {
        return static_cast<int>(lower_bound(nodes.begin(), nodes.end(), v, byTin) - nodes.begin());
    };

    // 2) Virtual-tree parent links. `nodes` is in pre-order, so a stack holding the current
    //    ancestor chain yields each node's virtual parent. Weights are 64-bit root-distance
    //    differences, since a compressed edge may span a very long path.
    vector<int> par(K, -1);
    vector<long long> pw(K, 0);
    vector<int> st;
    for (int i = 0; i < K; ++i) {
        const int u = nodes[i];
        while (!st.empty() &&
               !(tin[nodes[st.back()]] <= tin[u] && tout[u] <= tout[nodes[st.back()]])) {
            st.pop_back();
        }
        if (!st.empty()) {
            par[i] = st.back();
            pw[i] = static_cast<long long>(dtr[u]) - static_cast<long long>(dtr[nodes[st.back()]]);
        }
        st.push_back(i);
    }

    // 3) Bottom-up DP. Visiting indices in reverse pre-order finishes all descendants
    //    before their ancestor. Afterwards da[i]/db[i] is the distance from i to the nearest
    //    a/b inside i's virtual subtree.
    //    da[i] + db[i] is always the length of some real a-to-b walk, so it never undercuts
    //    the answer. At the LCA of the optimal pair it is at most that pair's distance.
    //    So the minimum over all i is exact, and no top-down pass is required.
    vector<long long> da(K, INF), db(K, INF);
    for (int i = 0; i < n; ++i) da[indexOf(a[i])] = 0;
    for (int i = 0; i < m; ++i) db[indexOf(b[i])] = 0;

    long long ans = INF;
    for (int i = K - 1; i >= 0; --i) {
        ans = min(ans, da[i] + db[i]);
        const int p = par[i];
        if (p >= 0) {
            da[p] = min(da[p], da[i] + pw[i]);
            db[p] = min(db[p], db[i] + pw[i]);
        }
    }
    return ans;
}
/**
 * Builds the tree with n nodes (ids 0..n-1) from n-1 edges (U[i], V[i]) of weight W[i],
 * then roots it at node 0 and precomputes tin/tout, dep, dtr and the lifting table `up`.
 * n must not exceed MAXN.
 */
void Init(int n, int U[], int V[], int W[]) {
    timer = 0;
    if (n <= 0) {
        // Empty tree: leave an empty lifting table instead of running dfs on node 0,
        // which would index up[0] out of range.
        lg = 0;
        up.clear();
        return;
    }
    for (int i = 0; i < n; ++i) adj[i].clear();
    for (int i = 0; i + 1 < n; ++i) {
        adj[U[i]].emplace_back(V[i], W[i]);
        adj[V[i]].emplace_back(U[i], W[i]);
    }
    fill(dep, dep + n, 0);
    fill(dtr, dtr + n, 0);
    // lg = floor(log2(n)), computed in integers. Depths are < n < 2^(lg+1), so levels
    // 0..lg cover every ancestor jump.
    lg = 0;
    while ((1LL << (lg + 1)) <= n) ++lg;
    up.assign(n, vector<int>(lg + 1, -1));
    dfs(0, -1);
}
/**
 * Entry point for one query: the smallest weighted distance between a node in A[0..s)
 * and a node in B[0..t), as computed by solve().
 */
long long Query(int s, int A[], int t, int B[]) {
    const long long best = solve(A, s, B, t);
    return best;
}
// Grader results (pasted from the submission page; kept as a comment so the file compiles):
// # | Verdict | Execution time | Memory | Grader output |
// ---|
// Fetching results... |
// # | Verdict | Execution time | Memory | Grader output |
// ---|
// Fetching results... |
// # | Verdict | Execution time | Memory | Grader output |
// ---|
// Fetching results... |