#include "factories.h"
#include <bits/stdc++.h>
using namespace std;
// Capacity of every per-vertex array: vertex ids must lie in [0, MAX).
constexpr int MAX = 500'005;
/// One adjacency-list entry: the neighbouring vertex and the weight of the
/// connecting edge.
struct edge {
    int nod;  ///< neighbour vertex id
    int len;  ///< weight of the edge to that neighbour
};
vector<edge>tree[MAX];
bool dead[MAX];
int centroid_dad[MAX];
int subsize[MAX];
/// Computes subtree sizes of the still-alive component containing `nod`,
/// rooted at `nod`, storing them in `subsize[]`.
///
/// @param nod   root of the traversal
/// @param tata  vertex to treat as nod's parent (not entered); -1 for none
/// @return      number of alive vertices reachable from nod, i.e. subsize[nod]
///
/// The traversal is iterative. A recursive DFS needs one stack frame per
/// vertex on the current path, so a path-shaped tree with ~5e5 vertices could
/// overflow the call stack.
int get_subsize(int nod,int tata){
    // (vertex, parent) pairs in BFS order; kept static so the buffer's
    // capacity is reused across the many calls made by decompose.
    static std::vector<std::pair<int,int>> order;
    order.clear();
    order.emplace_back(nod,tata);
    for(std::size_t i=0;i<order.size();++i){
        // Copy out before pushing: emplace_back may reallocate `order`.
        const int u=order[i].first;
        const int p=order[i].second;
        subsize[u]=1;
        for(auto [fiu,w] : tree[u])
            if(!dead[fiu] && fiu!=p)
                order.emplace_back(fiu,u);
    }
    // Every parent precedes its children in BFS order, so a reverse sweep
    // finishes each child's size before adding it to its parent.
    // Index 0 is the root; its parent is not part of the component.
    for(std::size_t i=order.size()-1;i>0;--i)
        subsize[order[i].second]+=subsize[order[i].first];
    return subsize[nod];
}
/// Walks from `nod` towards the heavy side until no alive child subtree has
/// more than total_sz/2 vertices, and returns that vertex (the centroid).
///
/// @param nod       starting vertex (the root used by the preceding get_subsize)
/// @param tata      vertex not to step back into; -1 for none
/// @param total_sz  size of the whole component, as returned by get_subsize
/// @return          the centroid of the component
///
/// Written as a loop rather than tail recursion. On a path the walk can be
/// ~N/2 steps long, and tail-call elimination is not guaranteed.
int find_centroid(int nod,int tata,int total_sz){
    const int half=total_sz/2;
    for(;;){
        int heavy=-1;
        // Same choice as before: the first qualifying neighbour in adjacency order.
        for(auto [fiu,w] : tree[nod])
            if(!dead[fiu] && fiu!=tata && subsize[fiu]>half){
                heavy=fiu;
                break;
            }
        if(heavy==-1)
            return nod;
        tata=nod;
        nod=heavy;
    }
}
// Root of the centroid tree; Init sets it to -1 and the first decompose call
// fills it in.
int whole_centroid;
// Number of centroid-tree levels reserved per vertex. Each level at least
// halves the component size, so about log2(MAX) + 1 levels are ever used.
constexpr int LOG = 23;
// dist_cen[v][k] = distance from v to its centroid ancestor on level k.
// A centroid's own entry on its level is never written and stays 0 (the
// array is zero-initialised), which is the correct distance to itself.
long long dist_cen[MAX][LOG];
void get_distance(int nod,int tata,long long dist,int niv){
dist_cen[nod][niv]=dist;
for(auto [fiu,w] : tree[nod])
if(!dead[fiu] && fiu!=tata)
get_distance(fiu,nod,dist+w,niv);
}
void decompose(int nod,int niv,int last_centroid){
int total_sz=get_subsize(nod,-1);
int centroid=find_centroid(nod,-1,total_sz);
if(whole_centroid==-1)
whole_centroid=centroid;
centroid_dad[centroid]=last_centroid;
dead[centroid]=1;
for(auto [vec,w] : tree[centroid])
if(!dead[vec])
get_distance(vec,centroid,w,niv);
for(auto [vec,w] : tree[centroid])
if(!dead[vec])
decompose(vec,niv+1,centroid);
}
/// Grader entry point: builds the tree from edge arrays and precomputes the
/// centroid decomposition.
///
/// @param N  number of vertices; A, B, D each hold N-1 entries, edge k joining
///           A[k] and B[k] with weight D[k]
void Init(int N, int A[], int B[], int D[]) {
    for (int k = 0; k < N - 1; ++k) {
        const int u = A[k];
        const int v = B[k];
        tree[u].push_back({v, D[k]});
        tree[v].push_back({u, D[k]});
    }
    whole_centroid = -1;  // decompose fills this with the root centroid
    decompose(0, 0, -1);
}
// Per-query scratch, indexed by centroid c. blue_nodes[c] and red_nodes[c]
// list the selected vertices whose centroid-ancestor chain passes through c.
std::vector<int> blue_nodes[MAX];
std::vector<int> red_nodes[MAX];
// active_sons[c]: child centroids of c touched by the current query.
std::vector<int> active_sons[MAX];
// active[c]: c has already been touched (and linked to its parent) this query.
bool active[MAX];
/// Registers vertex `nod` in `location[c]` for every centroid ancestor c of
/// `nod` (including nod itself). The first time a centroid is touched in this
/// query, it is also linked into its parent's active_sons list.
///
/// @param nod       selected vertex
/// @param location  blue_nodes or red_nodes, depending on the vertex's colour
void add_node(int nod,vector<int>location[]){
    for(int anc=nod; anc!=-1; anc=centroid_dad[anc]){
        location[anc].push_back(nod);
        if(active[anc])
            continue;  // already linked to its parent earlier in this query
        active[anc]=1;
        const int parent=centroid_dad[anc];
        if(parent!=-1)
            active_sons[parent].push_back(anc);
    }
}
// "No vertex found" sentinel. It is small enough that INF plus any real
// distance, or INF + INF, still fits in a long long.
constexpr long long INF = 1'000'000'000'000'000'000LL;
/// Lowers `x` to `val` when `val` is strictly smaller; otherwise leaves x alone.
void minself(long long& x,long long val){
    if(val<x)
        x=val;
}
/// Returns the smallest blue-red distance among pairs whose connecting path
/// passes through a centroid in the active subtree rooted at `nod`.
///
/// @param nod  an active centroid
/// @param niv  its level in the centroid tree (column of dist_cen to read)
/// @return     the minimum distance, or INF if no pair qualifies
long long find_min_distance(int nod,int niv){
    // Closest blue / red vertex to nod among nod itself and the children
    // processed so far.
    long long best_blue=INF;
    long long best_red=INF;
    for(int v : blue_nodes[nod])
        if(v==nod)
            best_blue=0;
    for(int v : red_nodes[nod])
        if(v==nod)
            best_red=0;
    long long answer=(best_blue==0 && best_red==0) ? 0 : INF;

    for(int son : active_sons[nod]){
        // Closest blue / red vertex to nod inside this child's component.
        long long son_blue=INF;
        long long son_red=INF;
        for(int v : blue_nodes[son])
            minself(son_blue,dist_cen[v][niv]);
        for(int v : red_nodes[son])
            minself(son_red,dist_cen[v][niv]);
        // Pair this child's vertices only with nod or earlier children. Pairs
        // inside one child are handled deeper down.
        minself(answer,son_blue+best_red);
        minself(answer,son_red+best_blue);
        minself(best_blue,son_blue);
        minself(best_red,son_red);
    }

    for(int son : active_sons[nod])
        minself(answer,find_min_distance(son,niv+1));
    return answer;
}
/// Resets the per-query scratch state of every centroid ancestor of `nod`
/// (including nod itself).
void clear_vectors(int nod){
    for(int c=nod; c!=-1; c=centroid_dad[c]){
        blue_nodes[c].clear();
        red_nodes[c].clear();
        active_sons[c].clear();
        active[c]=0;
    }
}
/// Grader entry point: minimum distance between any vertex of X[0..S) and any
/// vertex of Y[0..T). Returns INF if no pair is found.
long long Query(int S, int X[], int T, int Y[]) {
    for (int k = 0; k < S; ++k)
        add_node(X[k], blue_nodes);
    for (int k = 0; k < T; ++k)
        add_node(Y[k], red_nodes);

    const long long answer = find_min_distance(whole_centroid, 0);

    // Only the chains that were touched need resetting, which keeps the
    // cleanup proportional to the query size.
    for (int k = 0; k < S; ++k)
        clear_vectors(X[k]);
    for (int k = 0; k < T; ++k)
        clear_vectors(Y[k]);
    return answer;
}
/*
Grader results table (pasted from the judge's web page; not part of the program):
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
*/