#include "factories.h"
#include <bits/stdc++.h>
using namespace std;
// 64-bit integer used for summed edge lengths.
typedef long long ll;
// n: node count, stored by Init.
// j: stamp of the current query; Query compares dp[.].second against it and
//    increments it once per call, so stale dp entries never need clearing.
int n,j=0;
// adj[v]: (neighbour, edge length) pairs of node v; Init adds each edge in both directions.
vector<pair<int,int>> adj[500001];
// lift[v][k]: {ancestor of v, summed edge length to that ancestor}. Level 0 is the direct
// parent (written by dfs_dist); Init builds level k from two level k-1 entries.
pair<int,ll> lift[500001][20];
// depth[v]: number of edges between v and the root chosen by dfs_dist.
int depth[500001];
// sz[v]: subtree size written by dfs_sizes and read by find_centroid.
int sz[500001];
// marked[v]: set once v has been picked as a centroid; dfs_sizes and find_centroid skip marked nodes.
bool marked[500001]={0};
// decomp_parent[v]: value stored by centroid_decomposition for each centroid (-1 for the
// top-level call); Query follows this chain upwards until it reads -1.
int decomp_parent[500001];
// dp[v]: {smallest distance recorded from an X node of the stamped query to v, query stamp}.
// dfs_dist resets every visited entry to {-1,-1}.
pair<ll,int> dp[500001];
// Depth-first walk that roots the tree at the node of the outermost call.
// For every visited node it records depth[.] and the level-0 lifting entry
// {parent, length of the edge to the parent}, and resets dp[.] to {-1,-1}.
//   v - node being visited
//   p - node we arrived from; -1 marks the root, which becomes its own parent at distance 0
void dfs_dist(int v, int p=-1){
    dp[v]={-1,-1};
    const bool at_root=(p==-1);
    if(at_root){
        depth[v]=0;
        lift[v][0]={v,0};
    }
    for(const auto& edge : adj[v]){
        const int child=edge.first;
        const int length=edge.second;
        if(child==p){
            continue;  // do not walk back up the edge we came down
        }
        lift[child][0]={v,length};
        depth[child]=depth[v]+1;
        dfs_dist(child,v);
    }
}
// Returns the summed edge length of the tree path between nodes a and b,
// using the binary-lifting table lift[v][k] = {ancestor, length to that ancestor}.
//
// The length of a jump must be read from the node we are leaving, so every
// jump below adds lift[x][k].second *before* replacing x with its ancestor.
ll dist(int a, int b){
    if(depth[a]<depth[b])swap(a,b);
    ll total=0;

    // Raise the deeper node to the depth of the other one, one binary digit
    // of the depth difference at a time.
    const int diff=depth[a]-depth[b];
    for(int k=0;k<20;k++){
        if((diff>>k)&1){
            total+=lift[a][k].second;
            a=lift[a][k].first;
        }
    }
    if(a==b)return total;  // b was an ancestor of a

    // Lift both nodes together while their ancestors still differ; afterwards
    // both sit directly below their lowest common ancestor.
    for(int k=19;k>=0;k--){
        if(lift[a][k].first!=lift[b][k].first){
            total+=lift[a][k].second;
            total+=lift[b][k].second;
            a=lift[a][k].first;
            b=lift[b][k].first;
        }
    }

    // Final step from each side up to the common ancestor.
    total+=lift[a][0].second;
    total+=lift[b][0].second;
    return total;
}
// Computes sz[u] for every node u reachable from v without crossing a marked
// node, with v acting as the root of the traversal.
//   v - node whose subtree is measured
//   p - node we arrived from (-1 at the root of the traversal)
// Returns sz[v], the number of nodes counted in v's subtree.
int dfs_sizes(int v, int p=-1){
    sz[v]=1;
    for(const auto& edge : adj[v]){
        const int w=edge.first;
        if(w==p||marked[w]){
            continue;
        }
        // The parent has to be forwarded: with the default p=-1 the child
        // would step straight back into v and the two calls would recurse
        // into each other without end.
        sz[v]+=dfs_sizes(w,v);
    }
    return sz[v];
}
// Walks downwards from v, always stepping into the first unmarked neighbour
// (other than the node we came from) whose sz[] exceeds half of subsz, and
// returns the node at which no such neighbour exists.
//   v     - starting node
//   subsz - size of the component being split
//   p     - node to treat as v's parent (-1 for none)
int find_centroid(int v, int subsz, int p=-1){
    const int half=subsz/2;
    for(;;){
        int heavy=-1;  // node indices are non-negative, so -1 means "none found"
        for(const auto& edge : adj[v]){
            const int w=edge.first;
            if(w==p||marked[w]){
                continue;
            }
            if(sz[w]>half){
                heavy=w;
                break;
            }
        }
        if(heavy==-1){
            return v;
        }
        p=v;
        v=heavy;
    }
}
// Builds the centroid tree of the component containing v.
// The centroid of the component is marked, gets decomp_parent = p, and every
// remaining component around it is decomposed recursively with that centroid
// as parent.
//   v - any node of the component still to be decomposed
//   p - centroid one level up (-1 for the top-level call)
void centroid_decomposition(int v, int p=-1){
    const int component_size=dfs_sizes(v);
    v=find_centroid(v,component_size);
    decomp_parent[v]=p;
    marked[v]=1;
    for(const auto& edge : adj[v]){
        const int w=edge.first;
        if(!marked[w]){
            // Recurse instead of merely locating the sub-centroid: otherwise
            // only the top centroid ever receives a decomp_parent and the
            // upward walks in Query never reach -1.
            centroid_decomposition(w,v);
        }
    }
}
// One-time preprocessing of the tree.
//   N       - number of nodes; nodes are indexed 0..N-1
//   A, B, D - edge list: edge i joins A[i] and B[i] with length D[i].
//             A tree on N nodes has N-1 edges, so only indices 0..N-2 are read.
// Builds the adjacency lists, the binary-lifting table and the centroid tree.
void Init(int N, int A[], int B[], int D[]) {
    n=N;
    for(int i=0;i+1<n;i++){
        adj[A[i]].push_back({B[i],D[i]});
        adj[B[i]].push_back({A[i],D[i]});
    }

    // Node 0 exists for every N >= 1, so it is a safe root.
    dfs_dist(0);

    // Level k jumps 2^k parents: chain two level k-1 jumps and add their lengths.
    for(int k=1;k<20;k++){
        for(int i=0;i<n;i++){
            const auto& lower=lift[i][k-1];
            const auto& upper=lift[lower.first][k-1];
            lift[i][k]={upper.first, upper.second+lower.second};
        }
    }

    centroid_decomposition(0);
}
// Returns the smallest tree distance between any node of X[0..S-1] and any
// node of Y[0..T-1], or LLONG_MAX when no pair was found.
//
// Phase 1 walks from every X node up the centroid tree and stores, per
// centroid ancestor, the smallest distance to an X node. Phase 2 walks from
// every Y node and combines its distance to each ancestor with the stored one.
// dp entries are tagged with the query stamp j, so nothing is cleared between calls.
ll Query(int S, int X[], int T, int Y[]) {
    for(int i=0;i<S;i++){
        const int start=X[i];
        for(int v=start;v!=-1;v=decomp_parent[v]){
            if(dp[v].second<j)dp[v]={LLONG_MAX,j};  // first touch in this query
            dp[v].first=min(dp[v].first,dist(v,start));
        }
    }

    ll ans=LLONG_MAX;
    for(int i=0;i<T;i++){
        const int start=Y[i];
        for(int v=start;v!=-1;v=decomp_parent[v]){
            // An ancestor not stamped in phase 1 holds no X distance. Skipping it
            // avoids evaluating LLONG_MAX + dist, which is signed overflow and
            // would otherwise poison the minimum.
            if(dp[v].second!=j)continue;
            ans=min(ans,dp[v].first+dist(v,start));
        }
    }

    j++;
    return ans;
}
// Judge results:
// | # | Result        | Execution time | Memory    | Grader output                  |
// | 1 | Runtime error | 269 ms         | 524288 KB | Execution killed with signal 9 |
// | 2 | Halted        | 0 ms           | 0 KB      | -                              |
// Judge results:
// | # | Result        | Execution time | Memory    | Grader output                  |
// | 1 | Runtime error | 268 ms         | 524288 KB | Execution killed with signal 9 |
// | 2 | Halted        | 0 ms           | 0 KB      | -                              |
// Judge results:
// | # | Result        | Execution time | Memory    | Grader output                  |
// | 1 | Runtime error | 269 ms         | 524288 KB | Execution killed with signal 9 |
// | 2 | Halted        | 0 ms           | 0 KB      | -                              |