#include "factories.h"
#include <cstdio>
#include <algorithm>
#include <vector>
// NOTE(review): `z` expands to exit(0) wherever the bare token `z` appears;
// never use `z` as an identifier in this file.
#define z exit(0)
typedef long long ll;
#define F first
#define S second
#define mp make_pair
using namespace std;
using pii = pair<int,int>;
// N bounds per-vertex arrays. M bounds anything that can hold ~2n entries:
// the Euler tour (2n-1 entries) and the query node list V, which stores the
// s+t query vertices plus one LCA per tin-adjacent pair (up to 2(s+t)-1).
const int N = 5e5 + 5, M = 1e6 + 5;
const ll inf = 1e18;
// g[u]: (weight, neighbour) pairs of the input tree; vt_g: virtual-tree scratch.
vector<pii> g[N]; vector<int> vt_g[N];
// col: per-vertex query marks; tin/tout: first/last Euler position; euler: the
// Euler tour (length m); ST: sparse table over tin values; T: Euler position ->
// vertex; V/st: virtual-tree construction scratch (V sized M, see above).
int col[N], tin[N], tout[N], LOG, m, euler[M], ST[M][21], T[M], V[M], st[N];
// dp[v]: weighted distance from vertex 0 to v.
ll dp[N];
void dfs(int u, int p){
euler[m++] = u;
for(pii it: g[u]){
int w = it.F, v = it.S;
if(v != p){ dp[v] = dp[u] + w; dfs(v, u); euler[m++] = u;}
}
}
// Lowest common ancestor of u and v: among the Euler positions between the
// first visits of u and v, the smallest first-visit time belongs to the LCA.
// Answered in O(1) from the sparse table ST; T maps that time back to a vertex.
int LCA(int u, int v){
    int lo = tin[u];
    int hi = tin[v];
    if(hi < lo){
        const int tmp = lo;
        lo = hi;
        hi = tmp;
    }
    const int lvl = 31 - __builtin_clz(hi - lo + 1);
    const int lhs = ST[lo][lvl];
    const int rhs = ST[hi - (1 << lvl) + 1][lvl];
    return T[lhs < rhs ? lhs : rhs];
}
// Weighted tree distance: both root distances minus the shared part above the LCA.
ll d(int u, int v){
    const ll via = dp[LCA(u, v)];
    return (dp[u] - via) + (dp[v] - via);
}
// Loads the n-1 weighted edges (A[i], B[i], D[i]) into g, builds the Euler
// tour and root distances with dfs from vertex 0, then fills the LCA tables:
//   tin[v] / tout[v] = first / last position of v in euler[],
//   T[pos]           = euler[pos],
//   ST[pos][lv]      = min of tin over euler[pos .. pos + 2^lv - 1].
void Init(int n, int A[], int B[], int D[]){
    for(int v = 0; v < n; ++v) g[v].clear();
    for(int e = 0; e + 1 < n; ++e){
        const int a = A[e], b = B[e], len = D[e];
        g[a].emplace_back(len, b);
        g[b].emplace_back(len, a);
    }
    m = 0;
    dp[0] = 0;
    dfs(0, 0);
    LOG = 32 - __builtin_clz(m);
    // Right-to-left scan: the value left in tin[] is the first occurrence.
    for(int pos = m - 1; pos >= 0; --pos){
        const int v = euler[pos];
        tin[v] = pos;
        tout[v] = pos;
        T[pos] = v;
    }
    // Left-to-right scan: the value left in tout[] is the last occurrence.
    for(int pos = 0; pos < m; ++pos){
        tout[euler[pos]] = pos;
        ST[pos][0] = tin[euler[pos]];
    }
    for(int lv = 1; lv < LOG; ++lv){
        const int half = 1 << (lv - 1);
        for(int pos = 0; pos + (1 << lv) <= m; ++pos)
            ST[pos][lv] = std::min(ST[pos][lv - 1], ST[pos + half][lv - 1]);
    }
}
// True when u is an ancestor of v (or u == v): v's Euler interval nests in u's.
bool anc(int u, int v){
    if(tin[v] < tin[u]) return false;
    return tout[v] <= tout[u];
}
int ded[N], sz[N], ct; ll ans, mn[3];
int fsz(int u){ sz[u] = 1; for(int v: vt_g[u]) if(!ded[v]) sz[u] += fsz(v); return sz[u];}
int fcen(int u, int tsz){ for(int v: vt_g[u]) if(sz[v] > tsz/2) return fcen(v, tsz); return u;}
void efs(int u){
if(col[u] && col[ct] && col[u] != col[ct]) ans = min(ans, d(ct, u));
if(col[u] == 1 && mn[2] < inf) ans = min(ans, d(ct, u) + mn[2]);
if(col[u] == 2 && mn[1] < inf) ans = min(ans, d(ct, u) + mn[1]);
for(int v: vt_g[u]) efs(v);
}
void ffs(int u){
mn[col[u]] = min(mn[col[u]], d(ct, u));
for(int v: vt_g[u]) ffs(v);
}
void cd(int u){
ded[ct = u = fcen(u, fsz(u))] = true;
for(int v: vt_g[u]) if(!ded[v]) efs(v), ffs(v);
mn[0] = mn[1] = mn[2] = inf;
for(int v: vt_g[u]) if(!ded[v]) cd(v);
}
ll Query(int s, int X[], int t, int Y[]){
for(int i = m = 0; i<s; ++i) col[X[i]] = 1, V[m++] = X[i];
for(int i = 0; i<t; ++i) col[Y[i]] = 2, V[m++] = Y[i];
auto cmp = [&](int u, int v) { return tin[u] < tin[v];};
sort(V, V+m, cmp);
int mm = m;
for(int i = 1; i<m; ++i) V[mm++] = LCA(V[i-1], V[i]);
m = mm; sort(V, V+m, cmp); m = unique(V, V+m) - V;
//
int sz = 0;
st[sz++] = V[0];
for(int i = 1; i<m; ++i){
for(; sz > 1 && !anc(st[sz-1], V[i]); --sz) vt_g[st[sz-2]].push_back(st[sz-1]);
st[sz++] = V[i];
}
for(; sz > 1; --sz) vt_g[st[sz-2]].push_back(st[sz-1]);
ans = mn[0] = mn[1] = mn[2] = inf; cd(V[0]);
for(int i = 0; i<m; ++i) col[V[i]] = ded[V[i]] = 0, vt_g[V[i]].clear();
return ans;
}
/*
signed main(){
int n, q; scanf("%d %d", &n, &q);
int A[n], B[n], D[n], X[n], Y[n];
for(int i = 0, u, v, w; i<n-1; ++i){
scanf("%d %d %d", &u, &v, &w);
A[i] = u; B[i] = v; D[i] = w;
}
Init(n, A, B, D);
for(int i = 0; i<q; ++i){
int s, t; scanf("%d %d", &s, &t);
for(int j = 0; j<s; ++j) scanf("%d", X+j);
for(int j = 0; j<t; ++j) scanf("%d", Y+j);
printf("%lld\n", Query(s, X, t, Y));
}
}
*/
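/*
Pasted grader results:
| #  | Result    | Execution time | Memory   | Grader output        |
|----|-----------|----------------|----------|----------------------|
| 1  | Incorrect | 23 ms          | 24412 KB | Output isn't correct |
| 2  | Halted    | 0 ms           | 0 KB     | -                    |
*/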
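/*
Pasted grader results:
| #  | Result    | Execution time | Memory   | Grader output        |
|----|-----------|----------------|----------|----------------------|
| 1  | Incorrect | 11 ms          | 24156 KB | Output isn't correct |
| 2  | Halted    | 0 ms           | 0 KB     | -                    |
*/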
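/*
Pasted grader results:
| #  | Result    | Execution time | Memory   | Grader output        |
|----|-----------|----------------|----------|----------------------|
| 1  | Incorrect | 23 ms          | 24412 KB | Output isn't correct |
| 2  | Halted    | 0 ms           | 0 KB     | -                    |
*/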