#include<bits/stdc++.h>
#include "factories.h"
using namespace std;
// Shorthand for the 64-bit integer type used for path lengths.
#define ll long long
// Makes `endl` expand to a newline string literal instead of std::endl.
#define endl "\n"
// NOTE(review): MOD is not referenced anywhere in the code visible here.
const int MOD = 998244353;
// Capacity of every per-node array below.
const int maxN = 500001;
// Number of binary-lifting levels stored per node.
const int maxlog = 20;
// dp[v][k]: the 2^k-th ancestor of v in the tree rooted at node 0, or -1 if
// there is none (filled by DFS_lca / build_lca).
int dp[maxN][maxlog];
// dists[v]: sum of edge weights on the path from node 0 to v.
ll dists[maxN];
// depths[v]: number of edges on the path from node 0 to v.
int depths[maxN];
// adj[v]: list of (neighbour, edge weight) pairs.
vector<pair<int, ll>> adj[maxN];
// done[v]: set once v has been picked as a centroid in build().
bool done[maxN];
// subtree_size[v]: scratch values written by find_size() for the component
// currently being processed.
int subtree_size[maxN];
// parents[v]: the centroid selected one level above v in build(), or -1 for
// the first centroid chosen.
int parents[maxN];
// ans[c]: smallest distance from centroid c to any node currently registered
// through add(); 1e18 means "nothing registered".
ll ans[maxN];
// Output slot of find_centroid().
int centr;
// Centroid chosen for the whole tree (written in build()).
int root;
// Walks the tree from x (whose parent is p, or -1 for the root) and records,
// for every node reached, its direct parent in dp[node][0], its depth in
// depths[] and its weighted distance from the start in dists[].
//
// depths[x] and dists[x] must already hold the values for x itself.
//
// The traversal uses an explicit stack instead of recursion: on a path-shaped
// tree the recursion depth would equal the number of nodes (up to maxN),
// which can exhaust the call stack.
void DFS_lca(int x, int p){
    dp[x][0] = p;
    std::vector<int> pending;
    pending.push_back(x);
    while(!pending.empty()){
        const int cur = pending.back();
        pending.pop_back();
        // dp[cur][0] was written before cur was pushed, so it is the parent.
        const int from = dp[cur][0];
        for(const auto& edge : adj[cur]){
            const int node = edge.first;
            if(node == from) continue;
            dp[node][0] = cur;
            depths[node] = depths[cur] + 1;
            dists[node] = dists[cur] + edge.second;
            pending.push_back(node);
        }
    }
}
// Prepares the ancestor table for nodes 0..n-1 with node 0 as the root.
// DFS_lca fills level 0 (direct parents); each higher level is derived by
// composing two jumps of the level below. A missing ancestor is stored as -1.
void build_lca(int n){
    depths[0] = 0;
    DFS_lca(0, -1);
    for(int level = 1; level < 20; ++level){
        for(int v = 0; v < n; ++v){
            const int half = dp[v][level - 1];
            // No 2^(level-1)-th ancestor means no 2^level-th ancestor either.
            dp[v][level] = (half == -1) ? -1 : dp[half][level - 1];
        }
    }
}
// Returns the node reached from x after climbing d parent links, taking the
// largest power-of-two jump that still fits at each step (levels 19 down to 0).
int kth(int x, int d){
    int cur = x;
    int remaining = d;
    for(int bit = 19; bit >= 0; --bit){
        const int step = 1 << bit;
        if(step > remaining) continue;
        cur = dp[cur][bit];
        remaining -= step;
    }
    return cur;
}
// Finds the common ancestor of a and b used by dist().
//
// First the deeper node is lifted to the depth of the other one. If the two
// nodes then coincide, that node is returned. Otherwise both are lifted in
// lock-step from the highest level down: when their 2^k-th ancestors differ
// both nodes jump; when they coincide the shared ancestor is remembered
// (clamped to 0 when the table entry is -1). The value remembered last is
// returned.
int lca(int a, int b){
    const int gap = depths[a] - depths[b];
    if(gap < 0){
        b = kth(b, -gap);
    }else if(gap > 0){
        a = kth(a, gap);
    }
    if(a == b) return a;

    int meet = -1;
    for(int k = 19; k >= 0; --k){
        const int up_a = dp[a][k];
        const int up_b = dp[b][k];
        if(up_a != up_b){
            a = up_a;
            b = up_b;
        }else{
            // A -1 entry (above the root) is reported as node 0.
            meet = (up_a < 0) ? 0 : up_a;
        }
    }
    return meet;
}
// Returns the weighted length of the tree path between a and b, computed from
// the root distances of a, b and their common ancestor.
//
// The result is 64-bit: dists[] holds long long sums of edge weights, and
// returning them through an int (as this function previously did) silently
// truncated any path longer than 2^31-1, corrupting add() and query().
long long dist(int a, int b){
    const int anc = lca(a, b);
    return dists[a] + dists[b] - 2 * dists[anc];
}
// Computes subtree_size[] for the component containing x, treating x as the
// root and ignoring nodes already marked in done[]. `par` is the neighbour of
// x that must not be entered (-1 for none). Returns subtree_size[x].
//
// Implemented without recursion: nodes are listed in BFS order (every node
// after its parent) and sizes are then accumulated back to front. A recursive
// version would need one stack frame per level, which can exhaust the call
// stack on a path-shaped tree.
int find_size(int x, int par){
    // (node, node it was reached from)
    std::vector<std::pair<int, int>> order;
    order.emplace_back(x, par);
    for(std::size_t i = 0; i < order.size(); ++i){
        // Copy out before emplace_back can reallocate the vector.
        const int cur = order[i].first;
        const int from = order[i].second;
        subtree_size[cur] = 1;
        for(const auto& edge : adj[cur]){
            const int node = edge.first;
            if(node == from || done[node]) continue;
            order.emplace_back(node, cur);
        }
    }
    // Children always sit at larger indices than their parent, so walking
    // backwards adds each finished subtree into its parent exactly once.
    for(std::size_t i = order.size() - 1; i > 0; --i){
        subtree_size[order[i].second] += subtree_size[order[i].first];
    }
    return subtree_size[x];
}
// Stores in `centr` the node reached by starting at x and repeatedly stepping
// into the child whose subtree_size exceeds siz / 2, stopping at the first
// node that has no such child. Nodes marked in done[] are ignored, and `par`
// is the neighbour of x that must not be entered (-1 for none).
//
// subtree_size[] must have been filled by find_size() from the same start
// node, with siz equal to the size it returned.
//
// The descent is a loop rather than recursion so that a long chain of heavy
// children cannot exhaust the call stack.
void find_centroid(int x, int par, int siz){
    int cur = x;
    int from = par;
    while(true){
        int heavy = -1;
        for(const auto& edge : adj[cur]){
            const int node = edge.first;
            if(node == from || done[node]) continue;
            if(subtree_size[node] > siz / 2) heavy = node;
        }
        if(heavy == -1) break;
        from = cur;
        cur = heavy;
    }
    centr = cur;
}
void build(int n){
queue<pair<int, int>> q;
q.push({0, -1});
while(!q.empty()){
int x = q.front().first; int par = q.front().second; q.pop();
int siz = find_size(x, -1);
centr = -1;
find_centroid(x, -1, siz);
if(par == -1) root = centr;
parents[centr] = par;
done[centr] = 1;
for(auto pr : adj[centr]){
int node = pr.first;
if(!done[node]) q.push({node, centr});
}
}
}
void add(int x){
int next = x;
while(next != -1){
ll d = dist(next, x);
ans[next] = min(ans[next], d);
next = parents[next];
}
}
// Undoes add(x): resets ans[] to the "nothing registered" value 1e18 for x
// and every node on its parents[] chain.
void rmv(int x){
    for(int c = x; c != -1; c = parents[c]){
        ans[c] = 1e18;
    }
}
// Returns the smallest value of ans[c] + dist(c, x) over x and every node c
// on its parents[] chain. The starting value is 1e18.
ll query(int x){
    ll best = 1e18;
    for(int c = x; c != -1; c = parents[c]){
        const ll d = dist(c, x);
        const ll through_c = ans[c] + d;
        if(through_c < best) best = through_c;
    }
    return best;
}
// Grader entry point: receives a tree with N nodes whose i-th edge joins A[i]
// and B[i] with weight D[i] (N - 1 edges), stores it in adj[], builds the
// ancestor table and the centroid hierarchy, and sets every ans[] entry to
// the "nothing registered" value 1e18.
void Init(int N, int A[], int B[], int D[]){
    const int edge_count = N - 1;
    for(int e = 0; e < edge_count; ++e){
        const int u = A[e];
        const int v = B[e];
        const int w = D[e];
        adj[u].push_back({v, w});
        adj[v].push_back({u, w});
    }
    build_lca(N);
    build(N);
    for(int v = N - 1; v >= 0; --v){
        ans[v] = 1e18;
    }
}
// Grader entry point: returns the smallest query() result over the T nodes in
// Y after registering the S nodes in X with add(). The registrations are
// removed again with rmv() before returning, so ans[] is back to its initial
// state for the next call.
ll Query(int S, int X[], int T, int Y[]){
    for(int s = 0; s < S; ++s){
        add(X[s]);
    }

    ll best = 1e18;
    for(int t = 0; t < T; ++t){
        const ll candidate = query(Y[t]);
        if(candidate < best) best = candidate;
    }

    for(int s = 0; s < S; ++s){
        rmv(X[s]);
    }
    return best;
}
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Correct |
23 ms |
12628 KB |
Output is correct |
2 |
Correct |
719 ms |
30364 KB |
Output is correct |
3 |
Incorrect |
1407 ms |
30576 KB |
Output isn't correct |
4 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Correct |
9 ms |
12244 KB |
Output is correct |
2 |
Correct |
2841 ms |
123632 KB |
Output is correct |
3 |
Incorrect |
6371 ms |
125716 KB |
Output isn't correct |
4 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Correct |
23 ms |
12628 KB |
Output is correct |
2 |
Correct |
719 ms |
30364 KB |
Output is correct |
3 |
Incorrect |
1407 ms |
30576 KB |
Output isn't correct |
4 |
Halted |
0 ms |
0 KB |
- |