#include <bits/stdc++.h>
//#include "factories.h"
using namespace std;
using ll = long long;
// Capacity of every vertex-indexed array (N <= 500000, plus a little slack).
constexpr int mxN = 500010;
// "Infinity" for distances; small enough that INF + INF still fits in ll.
constexpr ll INF = 100'000'000'000'000'000LL;
// Number of centroid-decomposition levels a vertex can belong to. Each level at
// most halves the component (see find_centroid), and 2^19 > mxN, so 19 suffice.
constexpr int MAX_LEVELS = 19;

int n;                            // number of vertices, set by Init
vector<int> centroids[mxN];       // centroids[v][k]: centroid of v's component at level k
vector<array<int, 2>> adj[mxN];   // adj[v]: {neighbor, edge weight}
ll dist[mxN][MAX_LEVELS];         // dist[v][k]: distance from v to centroids[v][k]
int sz[mxN];                      // subtree sizes, scratch for dfs / find_centroid
bool blocked[mxN];                // true once a vertex has been chosen as a centroid
array<ll, 2> mn[mxN];             // Query scratch: per-centroid min distance to an X / Y vertex
ll ans;                           // result of the current Query
// Computes sz[v] (size of v's subtree) for every vertex of the non-blocked
// component containing `node`, treating `node` as the root and `p` as the
// vertex to come from (-1 for none).
//
// Iterative on purpose: a recursive DFS is as deep as the tree is tall, which
// for a path of ~5e5 vertices can overflow a default-sized call stack.
void dfs(int node, int p) {
    // (vertex, parent) pairs in BFS order: each vertex appears after its parent.
    std::vector<std::pair<int, int>> order;
    order.emplace_back(node, p);
    for (std::size_t i = 0; i < order.size(); ++i) {
        // Copy out before emplace_back below may reallocate `order`.
        const int u = order[i].first;
        const int par = order[i].second;
        sz[u] = 1;
        for (auto [it, w] : adj[u]) {
            if (it == par || blocked[it]) continue;
            order.emplace_back(it, u);
        }
    }
    // Walk backwards so each child's size is final before it is added to its
    // parent; index 0 is the root, which has no parent to update.
    for (std::size_t i = order.size(); i-- > 1;) {
        sz[order[i].second] += sz[order[i].first];
    }
}
// Starting from `node` (the root used by the preceding dfs call, with `p` the
// vertex it was entered from), walks toward any child whose subtree holds more
// than half of the `total` vertices, and returns the first vertex that has no
// such child: a centroid of the non-blocked component.
//
// Written as a loop rather than recursion: the walk can be about total/2 steps
// long (e.g. on a path), and the recursive form was only safe if the compiler
// turned the tail call into a jump. The parameter is named `total` so it no
// longer shadows the global `n`.
int find_centroid(int node, int p, int total) {
    while (true) {
        int heavy = -1;
        for (auto [it, w] : adj[node]) {
            if (it == p || blocked[it]) continue;
            if (sz[it] * 2 > total) {
                heavy = it;
                break;
            }
        }
        if (heavy == -1) return node;
        p = node;
        node = heavy;
    }
}
// For every vertex v in the non-blocked component reachable from `node`
// (entered from `p`, -1 for none), records the distance from v to centroid
// `cen` as dist[v][level] and appends `cen` to centroids[v]. Here level is the
// number of centroids v already had. `d` is the distance from `cen` to `node`.
//
// Uses an explicit stack instead of recursion: the component can be a path of
// ~5e5 vertices, which is too deep for a default-sized call stack. Visit order
// does not matter, because each vertex is reached exactly once per call.
void dfs2(int node, int p, int cen, ll d) {
    // Pending (vertex, parent, distance from cen) entries.
    std::vector<std::tuple<int, int, long long>> todo;
    todo.emplace_back(node, p, d);
    while (!todo.empty()) {
        const auto [u, par, du] = todo.back();
        todo.pop_back();
        dist[u][centroids[u].size()] = du;
        centroids[u].push_back(cen);
        for (auto [it, w] : adj[u]) {
            if (it == par || blocked[it]) continue;
            todo.emplace_back(it, u, du + w);
        }
    }
}
void decomp(int node) {
dfs(node, -1);
int cen = find_centroid(node, -1, sz[node]);
dfs2(cen, -1, cen, 0);
blocked[cen] = true;
for(auto [it, w] : adj[cen]) {
if(blocked[it]) continue;
decomp(it);
}
}
// Grader entry point. Stores N and builds the undirected weighted tree from
// the N-1 edges (A[e], B[e]) of weight D[e]. Then precomputes the centroid
// decomposition starting from vertex 0.
void Init(int N, int A[], int B[], int D[]) {
    n = N;
    for (int e = 0; e + 1 < N; ++e) {
        const int u = A[e];
        const int v = B[e];
        const int w = D[e];
        adj[u].push_back({v, w});
        adj[v].push_back({u, w});
    }
    decomp(0);
}
long long Query(int S, int X[], int T, int Y[]) {
ans = INF;
for(int i = 0; i < S; i++) {
int now = X[i];
for(auto it : centroids[now]) {
mn[it] = {INF, INF};
}
}
for(int i = 0; i < T; i++) {
int now = Y[i];
for(auto it : centroids[now]) {
mn[it] = {INF, INF};
}
}
for(int i = 0; i < S; i++) {
int now = X[i];
for(int j = 0; j < centroids[now].size(); j++) {
int it = centroids[now][j];
mn[it][0] = min(mn[it][0], dist[now][j]);
}
}
for(int i = 0; i < T; i++) {
int now = Y[i];
for(int j = 0; j < centroids[now].size(); j++) {
int it = centroids[now][j];
mn[it][1] = min(mn[it][1], dist[now][j]);
}
}
for(int i = 0; i < S; i++) {
int now = X[i];
for(int j = 0; j < centroids[now].size(); j++) {
int it = centroids[now][j];
ans = min(ans, mn[it][0] + mn[it][1]);
}
}
for(int i = 0; i < T; i++) {
int now = Y[i];
for(int j = 0; j < centroids[now].size(); j++) {
int it = centroids[now][j];
ans = min(ans, mn[it][0] + mn[it][1]);
}
}
return ans;
}
/*int main()
{
int A[7] = {0, 1, 2, 2, 4, 1};
int B[7] = {1, 2, 3, 4, 5, 6};
int D[7] = {4, 4, 5, 6, 5, 3};
Init(7, A, B, D);
int C[3] = {2};
int E[3] = {5};
cout << Query(1, C, 1, E);
}*/
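/*
Judge results table pasted from the submission page (results had not loaded yet);
kept as a comment so the file still compiles.
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
*/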