#include "factories.h"
#include <bits/stdc++.h>
using namespace std;
typedef pair<int, int> pi;
typedef long long ll;
// Capacity of every node-indexed array; node ids must be < N.
const int N = (int)5e5 + 10;
// Binary-lifting levels: 2^L = 524288 > N, so any ancestor distance fits.
const int L = 19;
// Per-query colour of a node: 0 for X, 1 for Y, -1 otherwise (set in Query).
int type[N];
// res: smallest type-0 / type-1 distance found so far by dfs2.
// f[v][t]: distance from v down to the closest type-t node inside v's
//          virtual subtree, or 1LL<<60 when there is none.
ll res, f[N][2];
// Virtual-tree children of each node as (edge length, child); rebuilt per query.
vector< pair<ll, int> > virt_aj[N];

// Post-order DP over the virtual tree rooted at v.
// On return f[x][*] is filled for every node x under v, and res has been
// lowered to the best cross distance between a type-0 and a type-1 node
// found under v. Children are processed in virt_aj order, exactly like a
// recursive traversal. The explicit stack means a path-shaped virtual tree
// (depth up to ~N) cannot overflow the call stack.
void dfs2(int v) {
    const ll INF = 1LL << 60;
    // Frame: (virtual-tree node, index of the next child to descend into).
    vector< pair<int, size_t> > stk;
    // Pre-order work: reset f[x], seed it with x's own colour, push a frame.
    auto enter = [&](int x) {
        f[x][0] = f[x][1] = INF;
        if (type[x] != -1) {
            f[x][type[x]] = 0;
        }
        stk.emplace_back(x, 0);
    };
    enter(v);
    while (!stk.empty()) {
        const int x = stk.back().first;
        const size_t idx = stk.back().second;
        if (idx < virt_aj[x].size()) {
            ++stk.back().second;  // consume edge idx before descending
            enter(virt_aj[x][idx].second);
            continue;
        }
        // Every child of x is finished: fold x's summary into its parent.
        stk.pop_back();
        if (stk.empty()) {
            break;
        }
        const int p = stk.back().first;
        // The edge p -> x is the one p's frame consumed most recently.
        const ll w = virt_aj[p][stk.back().second - 1].first;
        // Pair a node below x with one already merged into p (or p itself).
        res = min(res, min(f[x][0] + w + f[p][1], f[x][1] + w + f[p][0]));
        for (int t = 2; t--; ) {
            f[p][t] = min(f[p][t], f[x][t] + w);
        }
    }
}
// DFS clock and Euler-tour times: u lies in v's subtree iff
// tin[v] <= tin[u] < tout[v]. par[v][i] is the 2^i-th ancestor of v, or -1.
int timer=0, tin[N], tout[N], par[N][L];
// sum[v]: weighted depth of v (distance from the DFS start vertex).
ll sum[N];
// Original tree adjacency as (edge length, neighbour).
vector<pi> aj[N];

// DFS of the original tree from v, whose parent is p (-1 for the root).
// For each vertex x reached it does three things:
// - stamps tin[x] and tout[x];
// - builds x's binary-lifting row par[x][*], with -1 past the root;
// - sets sum[x] = sum[parent] + edge length.
// Neighbours are visited in aj[] order, giving the same pre-order as the
// recursive form. An explicit stack is used so a path-shaped tree
// (depth close to N) cannot overflow the call stack.
void dfs(int v, int p=-1) {
    struct Frame {
        int node;     // vertex being expanded
        int parent;   // its DFS parent (-1 for the start vertex)
        size_t next;  // index into aj[node] of the next neighbour to try
    };
    vector<Frame> stk;
    // Pre-order work for x: stamp tin, build its ancestor row, push a frame.
    auto enter = [&stk](int x, int px) {
        tin[x] = timer++;
        memset(par[x], -1, sizeof(par[x]));
        par[x][0] = px;
        for (int i = 0; i + 1 < L && par[x][i] != -1; ++i) {
            par[x][i + 1] = par[par[x][i]][i];
        }
        Frame fr = {x, px, 0};
        stk.push_back(fr);
    };
    enter(v, p);
    while (!stk.empty()) {
        Frame& top = stk.back();
        if (top.next == aj[top.node].size()) {
            // All neighbours handled: top.node's subtree is complete.
            tout[top.node] = timer;
            stk.pop_back();
            continue;
        }
        const pi& edge = aj[top.node][top.next++];
        const int u = edge.second;
        if (u == top.parent) {
            continue;
        }
        sum[u] = sum[top.node] + edge.first;
        enter(u, top.node);  // may reallocate stk: `top` is not used afterwards
    }
}
// Ancestor test via Euler-tour intervals: true when p is c or an ancestor of c,
// i.e. c was entered no earlier than p and before p's subtree was closed.
bool anc(int p, int c) {
    return !(tin[c] < tin[p] || tout[p] <= tin[c]);
}
// Lowest common ancestor of u and v using the binary-lifting table par[][].
int lca(int u, int v) {
    if (anc(u, v)) return u;
    if (anc(v, u)) return v;
    // Lift v as high as possible while staying strictly outside u's ancestor
    // chain; the parent of that vertex is the LCA.
    for (int level = L - 1; level >= 0; --level) {
        const int up = par[v][level];
        if (up == -1 || anc(up, u)) {
            continue;
        }
        v = up;
    }
    return par[v][0];
}
void Init(int N, int A[], int B[], int D[]) {
for(int i = 0; i < N - 1; ++i) {
aj[A[i]].emplace_back(D[i], B[i]);
aj[B[i]].emplace_back(D[i], A[i]);
}
dfs(0);
}
long long Query(int S, int X[], int T, int Y[]) {
vector<pi> ord;
for(int i = 0; i < S; ++i)
ord.emplace_back(tin[X[i]], X[i]);
for(int i = 0; i < T; ++i)
ord.emplace_back(tin[Y[i]], Y[i]);
sort(ord.begin(), ord.end());
for(int i = 0, sz = ord.size(); i + 1 < sz; ++i) {
int x = lca(ord[i].second, ord[i + 1].second);
ord.emplace_back(tin[x], x);
}
sort(ord.begin(), ord.end());
ord.erase(unique(ord.begin(), ord.end()), ord.end());
for(const pi& tmp : ord) {
virt_aj[tmp.second].clear();
type[tmp.second] = -1;
}
for(int i = 0; i < S; ++i) {
type[X[i]] = 0;
}
for(int i = 0; i < T; ++i) {
type[Y[i]] = 1;
}
stack<int> stk;
for(const pi& tmp : ord) {
int c = tmp.second;
while(!stk.empty() && !anc(stk.top(), c)) {
stk.pop();
}
if(!stk.empty()) {
int r = stk.top();
virt_aj[r].emplace_back(sum[c] - sum[r], c);
}
stk.push(c);
}
int root = ord.front().second;
res = 1LL<<60;
dfs2(root);
return res;
}
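/*
Leftover online-judge submission-results table that was pasted below the code
while the results were still loading. It is kept as a comment so the file compiles.

# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
*/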