#include "factories.h"
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
constexpr ll INF = 1'000'000'000'000'000'000;  // "no distance yet" sentinel
constexpr int N = 500'000;                     // maximum number of vertices
constexpr int M = 20;                          // binary-lifting levels j = 0..M (2^M > N)
// n: vertex count (set by Init). q: used only by the commented-out driver.
// sz[v]: subtree size inside the current decomposition component.
// up[c]: parent of centroid c in the centroid tree (0 for the root centroid).
int n, q, sz[N + 10], up[N + 10];
// mark[v]: v has already been chosen as a centroid.
bool mark[N + 10];
// h[v]: depth of v below vertex 1.
// dp[v][j]: 2^j-th ancestor of v, or 0 if it does not exist.
// Both hold vertex ids / depths, so int suffices. dfs and LCA only touch
// columns 0..M, so dp needs just M + 1 of them (~42 MB instead of ~120 MB).
int h[N + 10], dp[N + 10][M + 1];
// mn[c]: best distance from any current query source to centroid c (INF when unset).
// d[v]: weighted distance from vertex 1 to v.
ll mn[N + 10], d[N + 10];
// adj[v]: (neighbour, edge length) pairs; vertices are 1-indexed.
vector<pair<int, ll>> adj[N + 10];
//int A[N + 10], B[N + 10], D[N + 10], X[N + 10], Y[N + 10], S, T;
void dfs(int u = 1, int par = 0) {
dp[u][0] = par;
for (int j = 1; j <= M && dp[u][j - 1]; j++)
dp[u][j] = dp[dp[u][j - 1]][j - 1];
for (auto [v, w]: adj[u])
if (v != par) {
h[v] = h[u] + 1;
d[v] = d[u] + w;
dfs(v, u);
}
}
// Lowest common ancestor of u and v using the binary-lifting table dp[][]
// and depths h[] filled by dfs().
int LCA(int u, int v) {
    // Let u be the deeper of the two endpoints.
    if (h[v] > h[u])
        swap(u, v);
    // Raise u to v's depth by jumping along the set bits of the depth gap.
    for (int j = 0, gap = h[u] - h[v]; gap > 0; ++j, gap >>= 1)
        if (gap & 1)
            u = dp[u][j];
    if (u == v)
        return u;
    // Raise both together as long as they stay below the common ancestor.
    for (int j = M; j >= 0; --j) {
        const int pu = dp[u][j];
        const int pv = dp[v][j];
        if (pu != pv) {
            u = pu;
            v = pv;
        }
    }
    return dp[u][0];
}
// Weighted tree distance between u and v: both root distances minus twice
// the root distance of their lowest common ancestor.
ll getDis(int u, int v) {
    const ll rootSum = d[u] + d[v];
    return rootSum - 2 * d[LCA(u, v)];
}
// Computes sz[x] for every vertex x in u's component (vertices not yet
// marked as centroids), with the component rooted at u and `par` treated as
// u's excluded parent. Returns sz[u], the component size.
//
// Fix: the previous version never recursed into children. It summed stale
// sz[] values, so the centroid search could not find real centroids and the
// centroid tree became deep.
// This version is iterative (BFS order, then accumulate in reverse), so it
// also avoids recursion depth proportional to the component size.
int calcSz(int u, int par = 0) {
    // Reused across calls to avoid reallocation; calcSz never re-enters itself.
    static vector<pair<int, int>> order;  // (vertex, parent) in BFS order
    order.clear();
    order.emplace_back(u, par);
    for (size_t i = 0; i < order.size(); i++) {
        const auto [x, p] = order[i];  // copy: emplace_back below may reallocate
        sz[x] = 1;
        for (const auto& [v, w] : adj[x])
            if (!mark[v] && v != p)
                order.emplace_back(v, x);
    }
    // Every vertex appears after its parent, so walking backwards finishes a
    // subtree before adding it to its parent. Index 0 (u itself) is skipped
    // because its parent lies outside the component.
    for (size_t i = order.size() - 1; i > 0; i--)
        sz[order[i].second] += sz[order[i].first];
    return sz[u];
}
// Returns a centroid of the unmarked component containing u.
// Starting at u, it repeatedly steps into an unmarked neighbour that is
// further from u (smaller sz) and whose subtree holds more than half the
// component. It stops when no such neighbour exists.
int calcCentroid(int u) {
    const int total = calcSz(u);
    int cur = u;
    int next = u;
    do {
        cur = next;
        for (const auto& [v, w] : adj[cur])
            if (!mark[v] && sz[v] < sz[cur] && sz[v] > total / 2)
                next = v;
    } while (next != cur);
    return cur;
}
int makeGraph(int u = 1) {
u = calcCentroid(u);
mark[u] = true;
for (auto [v, w]: adj[u])
if (!mark[v])
up[makeGraph(v)] = u;
return u;
}
// Builds the tree from N vertices and the N - 1 edges (A[i], B[i]) of length
// D[i], then precomputes LCA tables and the centroid decomposition.
// Incoming vertex ids are shifted by +1 so that 0 can serve as the
// "no vertex" sentinel used by dp[][] and up[].
// The caller's arrays are no longer modified; the previous version
// incremented A[i] and B[i] in place.
void Init(int N, int A[], int B[], int D[]) {
    n = N;
    for (int i = 0; i < n - 1; i++) {
        const int a = A[i] + 1;
        const int b = B[i] + 1;
        adj[a].push_back({b, D[i]});
        adj[b].push_back({a, D[i]});
    }
    dfs();
    makeGraph();
    fill(mn + 1, mn + n + 1, INF);
}
// Returns the minimum tree distance between any of the S vertices X[] and
// any of the T vertices Y[] (ids shifted by +1 internally, as in Init).
// The answer is INF if no pair is found.
//
// Every path between two vertices passes through some common centroid
// ancestor. So:
//  1. each source records its distance to each of its centroid ancestors in mn[];
//  2. each target combines mn[c] with its own distance to c;
//  3. mn[] is reset along the same chains so the next query starts clean.
// The caller's X[]/Y[] are left untouched. The previous version incremented
// them in place, and its reset loop silently relied on that.
long long Query(int S, int X[], int T, int Y[]) {
    for (int i = 0; i < S; i++) {
        const int x = X[i] + 1;
        for (int c = x; c; c = up[c])
            mn[c] = min(mn[c], getDis(x, c));
    }
    ll ans = INF;
    for (int i = 0; i < T; i++) {
        const int y = Y[i] + 1;
        for (int c = y; c; c = up[c])
            ans = min(ans, mn[c] + getDis(y, c));
    }
    for (int i = 0; i < S; i++)
        for (int c = X[i] + 1; c; c = up[c])
            mn[c] = INF;
    return ans;
}
/*
int main() {
cin >> n >> q;
for (int i = 0; i < n - 1; i++)
cin >> A[i] >> B[i] >> D[i];
Init(n);
vector<ll> vec;
while (q--) {
cin >> S >> T;
for (int i = 0; i < S; i++)
cin >> X[i];
for (int i = 0; i < T; i++)
cin >> Y[i];
vec.push_back(Query());
}
for (auto x: vec)
cout << x << '\n';
cout.flush();
return 0;
}*/
/*
Judge results:
# | Result              | Execution time | Memory   | Grader output
1 | Correct             |   22 ms        | 39512 KB | Output is correct
2 | Correct             |  811 ms        | 44180 KB | Output is correct
3 | Execution timed out | 8025 ms        | 43860 KB | Time limit exceeded
4 | Halted              |    0 ms        |     0 KB | -
*/
/*
Judge results:
# | Result              | Execution time | Memory    | Grader output
1 | Correct             |    7 ms        |  39516 KB | Output is correct
2 | Correct             | 1908 ms        | 194588 KB | Output is correct
3 | Execution timed out | 8016 ms        | 198224 KB | Time limit exceeded
4 | Halted              |    0 ms        |      0 KB | -
*/
/*
Judge results:
# | Result              | Execution time | Memory   | Grader output
1 | Correct             |   22 ms        | 39512 KB | Output is correct
2 | Correct             |  811 ms        | 44180 KB | Output is correct
3 | Execution timed out | 8025 ms        | 43860 KB | Time limit exceeded
4 | Halted              |    0 ms        |     0 KB | -
*/