#include <bits/stdc++.h>
#include "factories.h"
// Competitive-programming shorthand macros used throughout the file.
#define ll long long
#define pb push_back
#define eb emplace_back
#define pu push
#define ins insert
#define fi first
#define se second
#define all(a) a.begin(),a.end()
#define fastio ios_base::sync_with_stdio(false);cin.tie(0);cout.tie(0);
#define mpair make_pair
using namespace std;
//mt19937 mt(chrono::steady_clock::now().time_since_epoch().count());
// NOTE: despite the name, the second member is a 64-bit value:
// (neighbour node, edge weight / distance).
typedef pair<int, ll> pii;
// Not referenced anywhere in the visible code.
const int mod = 2147483647;
// "Infinity" for distances. Small enough that inf + inf still fits in a
// long long (Query adds dp1 + dp2, both of which may be inf).
const ll inf = 1e18 + 7;
// Capacity of every per-node global array.
const int N = 5e5 + 5;
// Euler-tour stamps written by dfs: tin[u] is u's pre-order index,
// tout[u] is the largest tin inside u's subtree; t is the running counter.
int tin[N], tout[N], t = 0;
// Input tree adjacency: adj[u] holds (neighbour, weight) pairs, both directions.
vector<pii> adj[N];
// Binary-lifting table: up[u][i] is the 2^i-th ancestor of u, or -1 above the
// root. dfs only writes indices 0..18.
int up[N][20];
// d[u]: weighted distance from the dfs root to u (Init roots at node 0, d[0] = 0).
ll d[N];
void dfs(int u, int par){
tin[u] = ++t;
up[u][0] = par;
for(int i = 1; i <= 18; i++){
if(up[u][i - 1] == -1) up[u][i] = -1;
else up[u][i] = up[up[u][i - 1]][i - 1];
}
for(pii v : adj[u]){
if(v.fi == par) continue;
d[v.fi] = d[u] + v.se;
dfs(v.fi, u);
}
tout[u] = t;
}
// True iff v lies in the subtree of u (u counts as its own ancestor),
// decided from the Euler-tour interval stamps produced by dfs.
bool inside(int u, int v){
    if(tin[u] > tin[v]) return false;
    return tout[v] <= tout[u];
}
// Lowest common ancestor of u and v via binary lifting on up[][].
int lca(int u, int v){
    // One endpoint already being an ancestor of the other settles it.
    if(inside(u, v)) return u;
    if(inside(v, u)) return v;
    // Lift u to the highest ancestor that is still NOT an ancestor of v;
    // its parent is then the answer.
    for(int k = 18; k >= 0; --k){
        const int jump = up[u][k];
        if(jump == -1 || inside(jump, v)) continue;
        u = jump;
    }
    return up[u][0];
}
// Weighted tree distance between u and v: both root paths minus the shared part.
ll dist(int u, int v){
    const ll shared = d[lca(u, v)];
    return d[u] + d[v] - 2 * shared;
}
// Per-query virtual tree: edge[u] holds (child, distance) pairs; entries are
// cleared at the end of every Query.
vector<pii> edge[N];
// dp1[u] / dp2[u]: best distance from u to an x-node / y-node in u's virtual
// subtree. Set to inf by Init and restored to inf after each Query.
ll dp1[N], dp2[N];
void solve(int u, int par, int state){
for(pii v : edge[u]){
solve(v.fi, u, state);
if(state == 1) dp1[u] = min(dp1[u], dp1[v.fi] + v.se);
else dp2[u] = min(dp2[u], dp2[v.fi] + v.se);
}
}
// Strict ordering of nodes by DFS entry time (pre-order).
bool comp(int u, int v){
    return tin[v] > tin[u];
}
// Scratch stack reused by Query while linking virtual-tree nodes (emptied at the start of each use).
stack<int> st;
// Returns the minimum weighted tree distance between any node of x[0..S-1]
// and any node of y[0..T-1], or inf if either set is empty.
//
// Builds the virtual (auxiliary) tree over the marked nodes plus the LCAs of
// tin-consecutive marked nodes (a set closed under LCA), runs the subtree
// minimum DP for both sets, and combines dp1 + dp2 at every virtual node. The
// optimal pair is caught at its LCA, which is always a virtual node. All
// touched dp1/dp2/edge entries are restored before returning.
ll Query(int S, int x[], int T, int y[]){
    // With an empty side there is no pair; also avoids sizing a vector to -1.
    if(S <= 0 || T <= 0) return inf;
    int sz = S + T;
    // Slots [0, sz) hold marked nodes, [sz, 2*sz-1) the adjacent-pair LCAs.
    vector<int> sorted(2 * sz - 1);
    for(int i = 0; i < S; i++){
        sorted[i] = x[i];
        dp1[x[i]] = 0;
    }
    for(int i = 0; i < T; i++){
        sorted[S + i] = y[i];
        dp2[y[i]] = 0;
    }
    sort(sorted.begin(), sorted.begin() + sz, comp);
    for(int i = 0; i + 1 < sz; i++){
        sorted[sz + i] = lca(sorted[i], sorted[i + 1]);
    }
    sort(sorted.begin(), sorted.end(), comp);
    // Equal nodes share a tin, so duplicates are adjacent after sorting.
    sorted.erase(unique(sorted.begin(), sorted.end()), sorted.end());
    int nwsize = sorted.size();
    // sorted[0] (smallest tin) is the LCA of everything, i.e. the virtual root.
    while(!st.empty()) st.pop();
    st.pu(sorted[0]);
    for(int i = 1; i < nwsize; i++){
        while(!st.empty() && !inside(st.top(), sorted[i])){
            st.pop();
        }
        // st.top() is an ancestor of sorted[i], so the path length is just the
        // difference of root distances; no extra LCA computation needed.
        edge[st.top()].pb(mpair(sorted[i], d[sorted[i]] - d[st.top()]));
        st.pu(sorted[i]);
    }
    solve(sorted[0], -1, 1);
    solve(sorted[0], -1, 2);
    ll ans = inf;
    for(int v : sorted){
        ans = min(ans, dp1[v] + dp2[v]);
    }
    // Restore the global state for the next query.
    for(int v : sorted){
        dp1[v] = inf;
        dp2[v] = inf;
        edge[v].clear();
    }
    return ans;
}
// One-time setup: n nodes, n-1 undirected edges a[e]-b[e] with weight w[e].
// Puts dp1/dp2 into their "no marked node" state (inf) for nodes 0..n, builds
// the adjacency lists, and preprocesses the tree rooted at node 0.
void Init(int n, int a[], int b[], int w[]){
    for(int i = 0; i <= n; i++) dp1[i] = dp2[i] = inf;
    for(int e = 0; e < n - 1; e++){
        adj[a[e]].eb(b[e], w[e]);
        adj[b[e]].eb(a[e], w[e]);
    }
    d[0] = 0;
    dfs(0, -1);
}
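/*
Apparent residue of an online-judge results page pasted after the code;
kept inside a comment so the file still compiles.

# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
*/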