//#include "factories.h"
#include <bits/stdc++.h>
using namespace std;
// Shorthand macros. In the code visible here only ii, F, S, FOR and FOD are
// used; task, no, yes, vec, _mp, sz(), all and evoid are not referenced.
#define task "main"
#define no "NO"
#define yes "YES"
#define F first
#define S second
#define vec vector
#define _mp make_pair
#define ii pair<int, int>
// sz(x) is a function-like macro, so it only expands when followed by '(';
// the global array `sz` (indexed as sz[u]) is unaffected by it.
#define sz(x) (int)x.size()
#define all(x) x.begin(), x.end()
// Prints val and returns from the enclosing void function.
#define evoid(val) return void(std::cout << val)
// Inclusive loops: FOR runs i = a, a+1, ..., b; FOD runs i = b, b-1, ..., a.
#define FOR(i, a, b) for(int i = (a); i <= (b); ++i)
#define FOD(i, b, a) for(int i = (b); i >= (a); --i)
// Replaces x with y when y is strictly smaller (i.e. x > y).
// Returns true exactly when x was changed.
template<class X, class Y>
bool minimize(X &x, Y y) {
    const bool improves = x > y;
    if (improves) x = y;
    return improves;
}
const int MAX_N = (int)5e5 + 5;          // capacity of every node-indexed array
const long long INF = (long long)1e18;   // sentinel meaning "no marked node found"
int numNode, numQuery;                    // numNode is set in Init; numQuery is not used in the visible code
vector<ii> adj[MAX_N];                    // adj[u] = list of (neighbour, edge weight) pairs
int sz[MAX_N], par[MAX_N];                // sz: subtree sizes inside the current component (countChild); par: parent in the centroid tree, -1 for its root
bool del[MAX_N];                          // del[u] is true once u has been chosen as a centroid
long long minDist[MAX_N]; // minDist[c]: smallest tree distance from centroid c to a node currently added via addNode inside c's centroid subtree; INF when there is none
struct LCA {
int par[20][MAX_N];
long long dist[MAX_N];
int high[MAX_N];
void dfs(int u, int p) {
for (ii e : adj[u]) {
int v = e.F, w = e.S;
if (v == p) continue;
par[0][v] = u;
dist[v] = dist[u] + w;
high[v] = high[u] + 1;
dfs(v, u);
}
}
void setup() {
dfs(0, -1);
FOR(i, 1, 19) FOR(u, 0, numNode - 1)
par[i][u] = par[i - 1][par[i - 1][u]];
}
int getLCA(int u, int v) {
if (high[u] < high[v]) swap(u, v);
FOD(i, 19, 0) if (high[par[i][u]] >= high[v])
u = par[i][u];
if (u == v) return u;
FOD(i, 19, 0) if (par[i][u] != par[i][v]) {
u = par[i][u];
v = par[i][v];
}
return par[0][u];
}
long long getDist(int u, int v) {
return dist[u] + dist[v] - 2LL * dist[getLCA(u, v)];
}
} lca;
void countChild(int u, int p) {
sz[u] = 1;
for (ii e : adj[u]) {
int v = e.F;
if (v == p || del[v]) continue;
countChild(v, u);
sz[u] += sz[v];
}
}
int centroid(int u, int p, int m) {
for (ii e : adj[u]) {
int v = e.F;
if (del[v] || v == p) continue;
if (sz[v] > m / 2) return centroid(v, u, m);
}
return u;
}
int decompose(int u) {
countChild(u, -1);
int m = sz[u];
u = centroid(u, -1, m);
del[u] = 1;
minDist[u] = INF;
for (ii e : adj[u]) {
int v = e.F;
if (del[v]) continue;
int x = decompose(v);
par[x] = u;
}
return u;
}
// Marks u for the current query: every centroid-tree ancestor of u (u
// included) lowers its minDist to its tree distance to u, if that is smaller.
void addNode(int u) {
    for (int anc = u; anc != -1; anc = par[anc])
        minimize(minDist[anc], lca.getDist(anc, u));
}
// Undoes addNode(u) by resetting minDist to INF on u and all of its
// centroid-tree ancestors. This resets rather than recomputes, so it is only
// correct when every added node is removed (as Query does).
void delNode(int u) {
    for (int anc = u; anc != -1; anc = par[anc]) {
        minDist[anc] = INF;
    }
}
// Returns the smallest distance from u to any node currently added via
// addNode, or INF when none is reachable. Each centroid-tree ancestor anc of
// u (u included) contributes minDist[anc] + dist(anc, u).
long long calcMinDist(int u) {
    long long best = INF;
    for (int anc = u; anc != -1; anc = par[anc]) {
        const long long viaAnc = minDist[anc] + lca.getDist(anc, u);
        minimize(best, viaAnc);
    }
    return best;
}
void Init(int n, int a[], int b[], int d[]) {
numNode = n;
FOR(i, 0, numNode - 2) {
adj[a[i]].push_back({b[i], d[i]});
adj[b[i]].push_back({a[i], d[i]});
}
int root = decompose(0);
par[root] = -1;
lca.setup();
}
long long Query(int s, int x[], int t, int y[]) {
FOR(i, 0, s - 1) addNode(x[i]);
long long ans = INF;
FOR(i, 0, t - 1) minimize(ans, calcMinDist(y[i]));
FOR(i, 0, s - 1) delNode(x[i]);
return ans;
}
/* Lak lu theo dieu nhac!!!! */
/* Judge results table pasted from the submission page (not part of the program):
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
*/