#ifndef FACTORIES_H
#define FACTORIES_H
// Builds the weighted tree later used by Query(): nodes 0..N-1 and N-1 undirected
// edges, where edge i joins A[i] and B[i] with length D[i]. Must be called before
// Query().
void Init(int N, int A[], int B[], int D[]);
// Returns the minimum tree-path length between any of the S nodes listed in X and
// any of the T nodes listed in Y.
long long Query(int S, int X[], int T, int Y[]);
#endif
#include "factories.h"
#include <bits/stdc++.h>
using namespace std;
// Capacity limits: node ids must be < MAXN. The binary-lifting table keeps jumps of
// 2^0 .. 2^(LG-1) edges, enough for any depth below 2^LG.
constexpr int MAXN = 500005;
constexpr int LG = 20;

// Number of nodes in the tree built by the latest Init().
int N;
// Adjacency lists: g[u] holds (neighbor, edge length) pairs.
std::vector<std::pair<int, int>> g[MAXN];
// f[u][k] = 2^k-th ancestor of u (Init roots the tree at 0, whose ancestor is itself).
int f[MAXN][LG];
// depth[u] = number of edges between u and the root.
int depth[MAXN];
// Traversal timestamps: tin = preorder entry time, tout = largest entry time in the subtree.
int tin[MAXN];
int tout[MAXN];
// Running timestamp counter advanced by dfs().
int tmr;
// dis[u] = total edge length between u and the root.
long long dis[MAXN];
// Depth-first traversal of the tree stored in g, starting at u, whose parent is p.
// For every node x it reaches, it records:
//   tin[x]   - preorder timestamp taken from the global counter tmr (pre-incremented)
//   tout[x]  - largest timestamp inside x's subtree, so y lies in x's subtree iff
//              tin[x] <= tin[y] && tout[y] <= tout[x]
//   f[x][0]  - parent of x (p for the start node)
//   depth[x] - depth[u] (set by the caller) plus the edge count on the path u -> x
//   dis[x]   - dis[u] (set by the caller) plus the total edge length on the path u -> x
// Children are visited in adjacency-list order, exactly as a recursive DFS would.
// An explicit stack is used instead of recursion: a path-shaped tree with ~5e5 nodes
// would otherwise need ~5e5 nested calls and can overflow the call stack.
void dfs(int u, int p) {
    // Each frame holds (node, index of the next adjacency entry to examine).
    std::vector<std::pair<int, std::size_t>> stk;
    tin[u] = ++tmr;
    f[u][0] = p;
    stk.push_back(std::make_pair(u, std::size_t(0)));
    while (!stk.empty()) {
        const int x = stk.back().first;
        std::size_t &next = stk.back().second;
        if (next == g[x].size()) {
            // Every child of x is finished: close x's timestamp range.
            tout[x] = tmr;
            stk.pop_back();
            continue;
        }
        const int v = g[x][next].first;
        const int w = g[x][next].second;
        ++next;  // advance before push_back below, which may invalidate `next`
        if (v == f[x][0]) continue;  // skip the edge leading back to x's parent
        depth[v] = depth[x] + 1;
        dis[v] = dis[x] + w;
        tin[v] = ++tmr;
        f[v][0] = x;
        stk.push_back(std::make_pair(v, std::size_t(0)));
    }
}
// Lowest common ancestor of u and v, using the binary-lifting table f and depth[]
// that Init() fills in.
int lca(int u, int v) {
    // Let u be the deeper of the two endpoints.
    if (depth[u] < depth[v]) std::swap(u, v);

    // Raise u to v's depth, trying the largest jumps first.
    for (int k = LG - 1; k >= 0; --k) {
        if (depth[u] - (1 << k) >= depth[v]) u = f[u][k];
    }
    if (u == v) return u;

    // Raise both while their ancestors differ; they end just below the LCA.
    for (int k = LG - 1; k >= 0; --k) {
        const int up_u = f[u][k];
        const int up_v = f[v][k];
        if (up_u != up_v) {
            u = up_u;
            v = up_v;
        }
    }
    return f[u][0];
}
// Length of the tree path between u and v: each endpoint's distance down from
// their lowest common ancestor, added together.
long long distBetween(int u, int v) {
    const long long meet = dis[lca(u, v)];
    return (dis[u] - meet) + (dis[v] - meet);
}
// Builds the tree over nodes 0..NN-1 from NN-1 undirected edges (A[e], B[e], length D[e]).
// It then roots the tree at node 0 and fills the LCA tables used by lca() and Query().
void Init(int NN, int A[], int B[], int D[]) {
    N = NN;

    // Reset per-node state that an earlier call may have left behind.
    for (int v = 0; v < N; ++v) {
        g[v].clear();
        dis[v] = 0;
        depth[v] = 0;
        std::fill(f[v], f[v] + LG, 0);
    }

    // Store each undirected edge in both endpoints' adjacency lists.
    for (int e = 0; e + 1 < N; ++e) {
        const int a = A[e];
        const int b = B[e];
        const int len = D[e];
        g[a].emplace_back(b, len);
        g[b].emplace_back(a, len);
    }

    // Root at 0; node 0 is recorded as its own parent.
    tmr = 0;
    dfs(0, 0);

    // The 2^k-th ancestor is the 2^(k-1)-th ancestor of the 2^(k-1)-th ancestor.
    for (int k = 1; k < LG; ++k)
        for (int v = 0; v < N; ++v)
            f[v][k] = f[f[v][k - 1]][k - 1];
}
// Returns the minimum tree-path length between any node in X[0..S) and any node in
// Y[0..T). Returns LLONG_MAX / 4 when either list is empty (no pair exists).
//
// Method: build the virtual (auxiliary) tree of X and Y together. That is the query
// nodes plus the LCAs of tin-adjacent pairs, stored in preorder in `vs`. Then run a
// bottom-up DP over it:
//   fa[i] / fb[i] = distance from vs[i] to the nearest X / Y node in its virtual subtree.
// Every X-Y pair meets at its LCA, which is a virtual-tree node, so the answer is
// min over i of fa[i] + fb[i].
// Cost per query is O((S+T) log N), with no recursion.
long long Query(int S, int X[], int T, int Y[]) {
    const long long INF = LLONG_MAX / 4;
    // With an empty side there is no pair. This also prevents indexing an empty `vs`.
    if (S <= 0 || T <= 0) return INF;

    auto byTin = [](int a, int b) { return tin[a] < tin[b]; };

    std::vector<int> vs;
    vs.reserve(2 * (S + T));
    vs.insert(vs.end(), X, X + S);
    vs.insert(vs.end(), Y, Y + T);
    std::sort(vs.begin(), vs.end(), byTin);
    const int m0 = static_cast<int>(vs.size());
    for (int i = 1; i < m0; ++i) vs.push_back(lca(vs[i - 1], vs[i]));
    std::sort(vs.begin(), vs.end(), byTin);
    vs.erase(std::unique(vs.begin(), vs.end()), vs.end());
    const int M = static_cast<int>(vs.size());

    // Index of a virtual-tree node in vs. tin values are distinct, so the match is exact.
    auto indexOf = [&](int node) {
        return static_cast<int>(
            std::lower_bound(vs.begin(), vs.end(), node, byTin) - vs.begin());
    };

    // par[i] = index of vs[i]'s parent in the virtual tree. vs[0] has the smallest tin
    // and is an ancestor of every other entry, so it is the root and is never popped.
    std::vector<int> par(M, -1);
    std::vector<int> stk(1, 0);  // indices into vs along the current root-to-node chain
    for (int i = 1; i < M; ++i) {
        const int v = vs[i];
        while (!(tin[vs[stk.back()]] <= tin[v] && tout[v] <= tout[vs[stk.back()]]))
            stk.pop_back();
        par[i] = stk.back();
        stk.push_back(i);
    }

    std::vector<long long> fa(M, INF), fb(M, INF);
    for (int i = 0; i < S; ++i) fa[indexOf(X[i])] = 0;
    for (int i = 0; i < T; ++i) fb[indexOf(Y[i])] = 0;

    // vs is in preorder, so every child has a larger index than its parent. A reverse
    // sweep therefore finalizes each node before folding it into its parent.
    long long ans = INF;
    for (int i = M - 1; i >= 0; --i) {
        ans = std::min(ans, fa[i] + fb[i]);
        if (i == 0) break;  // the root has no parent to update
        const int p = par[i];
        const long long d = distBetween(vs[p], vs[i]);
        fa[p] = std::min(fa[p], fa[i] + d);
        fb[p] = std::min(fb[p], fb[i] + d);
    }
    return ans;
}
/*
Grader results table, captured before the results finished loading (no verdicts recorded):
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
*/