#include<bits/stdc++.h>
#include "factories.h"
using namespace std;
// 64-bit integer type used for weighted distances.
using ll = long long;
// Array capacity: vertex ids are 1-indexed (Init shifts the 0-indexed input by one).
const int maxn = 500005;
// g[v]: adjacency of the original tree as (neighbor, edge weight).
vector<pair<int, int>> g[maxn];
// adj[v]: adjacency of the per-query virtual tree as (neighbor, path length).
vector<pair<int, ll>> adj[maxn];
// h[v]: depth of v counted in edges from vertex 1.
int h[maxn];
// up[v][j]: the 2^j-th ancestor of v; stays 0 once the jump passes the root.
int up[maxn][22];
// tin[v]: DFS entry time; out[v]: largest entry time inside v's subtree.
int tin[maxn];
int out[maxn];
// Running DFS clock shared by dfs().
int timeDfs;
// depth[v]: weighted distance from vertex 1 to v.
ll depth[maxn];
// Answer accumulator for the current query.
ll ans;
// f[v][0] / f[v][1]: nearest is1 / is2 vertex below v in the virtual tree.
ll f[maxn][2];
// Per-query membership marks for the X set (is1) and the Y set (is2).
bool is1[maxn];
bool is2[maxn];
void dfs(int u){
tin[u] = ++timeDfs;
for(auto [v, w]: g[u]){
if(v == up[u][0]) continue;
h[v] = h[u] + 1;
depth[v] = depth[u] + w;
up[v][0] = u;
for(int j = 1; j <= 21; j++) up[v][j] = up[up[v][j - 1]][j - 1];
dfs(v);
}
out[u] = timeDfs;
}
int lca(int u, int v){
if(h[u] != h[v]){
if(h[u] < h[v]) swap(u, v);
int k = h[u] - h[v];
for(int j = 0; (1 << j) <= k; j++){
if((k >> j) & 1){
u = up[u][j];
}
}
}
if(u == v) return u;
int k = __lg(h[u]);
for(int j = k; j >= 0; j--) {
if(up[u][j] != up[v][j]){
u = up[u][j];
v = up[v][j];
}
}
return up[u][0];
}
/// Weighted length of the tree path between u and v: the sum of both
/// endpoints' distances up to their lowest common ancestor.
ll dist(int u, int v){
    const ll meet = depth[lca(u, v)];
    return (depth[u] - meet) + (depth[v] - meet);
}
/// Post-order DP over the virtual tree stored in adj[], rooted at u
/// (pre is u's parent, or -1 for the root).
/// After processing v: f[v][0] / f[v][1] is the smallest distance from v down
/// to an is1 / is2 vertex in v's virtual subtree (it stays at the 1e16
/// sentinel if there is none). Every X–Y pair meets at its LCA, so
/// ans = min over v of f[v][0] + f[v][1].
///
/// Iterative (explicit stack) so a path-shaped virtual tree with hundreds of
/// thousands of vertices cannot overflow the call stack.
void dp(int u, int pre){
    // Each entry is (vertex, parent, children already scheduled?).
    vector<tuple<int, int, bool>> stk;
    stk.emplace_back(u, pre, false);
    while(!stk.empty()){
        const auto [x, p, expanded] = stk.back();
        stk.pop_back();
        if(!expanded){
            // Revisit x only after every child has been fully processed
            // (the children sit above it on the stack).
            stk.emplace_back(x, p, true);
            for(const auto &[v, w]: adj[x]){
                if(v != p) stk.emplace_back(v, x, false);
            }
            continue;
        }
        for(const auto &[v, w]: adj[x]){
            if(v == p) continue;
            f[x][0] = min(f[x][0], f[v][0] + w);
            f[x][1] = min(f[x][1], f[v][1] + w);
        }
        // A marked vertex is at distance 0 from its own set. This subsumes
        // the original four-way branch: both marks give f0 + f1 == 0.
        if(is1[x]) f[x][0] = 0;
        if(is2[x]) f[x][1] = 0;
        ans = min(ans, f[x][0] + f[x][1]);
    }
}
/// Builds the weighted tree from the N_ - 1 edges (A[e], B[e], D[e]).
/// Endpoints arrive 0-indexed and are shifted to 1-indexed. The tree is then
/// rooted at vertex 1 via dfs(), and every f[][] slot is primed with the
/// 1e16 sentinel that Query's DP treats as "no marked vertex".
void Init(int N_, int A[], int B[], int D[]){
    for(int e = 0; e + 1 < N_; ++e){
        const int a = A[e] + 1;
        const int b = B[e] + 1;
        const int len = D[e];
        g[a].push_back({b, len});
        g[b].push_back({a, len});
    }
    dfs(1);
    for(int v = 1; v <= N_; ++v){
        f[v][0] = 1e16;
        f[v][1] = 1e16;
    }
}
long long Query(int S, int X[], int T, int Y[]){
int s = S, t = T;
vector<int> comp1(s), comp2(t);
vector<int> vert;
for(int i = 0; i < s; i++) is1[++X[i]] = 1, vert.push_back(X[i]);
for(int i = 0; i < t; i++) is2[++Y[i]] = 1, vert.push_back(Y[i]);
sort(vert.begin(), vert.end(), [&](int c, int d) {return tin[c] < tin[d];});
int k = vert.size();
for(int i = 0; i < k - 1; i++) vert.push_back(lca(vert[i], vert[i + 1]));
sort(vert.begin(), vert.end(), [&](int c, int d) {return tin[c] < tin[d];});
vert.erase(unique(vert.begin(), vert.end()), vert.end());
stack<int> st;
for(int u: vert) {
while(st.size() && !(tin[st.top()] <= tin[u] && tin[u] <= out[st.top()])) st.pop();
if(st.size()){
int v = st.top();
ll w = dist(u, v);
adj[u].push_back({v, w});
adj[v].push_back({u, w});
}
st.push(u);
}
ans = 1e16;
dp(vert[0], -1);
for(int u: vert){
is1[u] = is2[u] = 0;
f[u][0] = f[u][1] = 1e16;
adj[u].clear();
}
return ans;
}
/*
Online-judge results table pasted together with this submission
(the results had not loaded yet):

# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
*/