Submission #369522

# | Time | Username | Problem | Language | Result | Execution time | Memory
369522 | | soroush | Factories (JOI14_factories) | C++17 | 100 / 100 | 3180 ms | 150356 KiB
#include "factories.h"
#include <bits/stdc++.h>

using namespace std;

// Capacity of every per-vertex array below. Init stores each endpoint as
// "given id + 1", so real vertices start at index 1 and index 0 serves as
// "no vertex" (dfs's default parent, an unset mx[] entry).
const int maxn = 5e5 + 100;

typedef long long ll;
// (neighbour, edge weight) entry of an adjacency list, as pushed by Init.
typedef pair < int , int > pii;

#define pb push_back

// st[v]: timestamp given to v when dfs first reaches it.
// ft[v]: value of the timestamp counter when dfs has finished v's subtree.
// ti:    the timestamp counter itself.
int st[maxn] , ft[maxn] , ti = 0;
// Adjacency lists of the input tree, filled by Init (both directions).
vector < pii > adj[maxn];
// h[v]: sum of edge weights on the path from the dfs start vertex to v.
ll h[maxn];

// Stands for "no marked vertex found" in solve and is the initial answer in Query.
const ll inf = 1e18;

// Written by solve for the vertices of the current query tree:
// red[v]  - smallest h-distance from v down to a vertex with r[] set,
// blue[v] - smallest h-distance from v down to a vertex with b[] set,
// or inf when the subtree holds no such vertex.
ll red[maxn] , blue[maxn];

void dfs(int v = 1 , int p = 0){
	st[v] = ++ti;
	for(auto [u , w] : adj[v])if(u ^ p)
		h[u] = h[v] + w , dfs(u , v);
	ft[v] = ti;
}

// Heavy-light decomposition data:
// sz[v]   - number of vertices in v's subtree (szdfs),
// H[v]    - number of edges between v and the szdfs start vertex (szdfs),
// head[v] - topmost vertex of the chain containing v (Init / chaindfs),
// par[v]  - parent of v, 0 for the start vertex (szdfs),
// mx[v]   - child of v with the largest subtree, 0 if v has no child (szdfs).
int sz[maxn] , H[maxn] , head[maxn] , par[maxn] , mx[maxn];

// Fills sz[], par[], H[] and mx[] for the subtree hanging below v.
//
// For every vertex x reached: par[x] is the vertex it was reached from,
// H[x] = H[par[x]] + 1, sz[x] is the number of vertices in x's subtree and
// mx[x] is the first child (in adj[] order) whose subtree is strictly larger
// than that of every earlier child. mx[x] keeps its previous value when x has
// no children.
//
// v: start vertex; par[v] and H[v] are read but not written, and the edge
//    from v to par[v] is not followed.
//
// Implemented with an explicit stack rather than recursion so that a very
// deep tree cannot overflow the call stack; children are still completed in
// adj[] order, which keeps the heavy-child tie-break unchanged.
void szdfs(int v = 1){
	struct Frame{
		int node;          // vertex currently being expanded
		std::size_t next;  // index of the next entry of adj[node] to examine
	};
	std::vector < Frame > pending;
	sz[v] = 1;
	pending.push_back({v , 0});
	while(!pending.empty()){
		const int cur = pending.back().node;
		const std::size_t idx = pending.back().next;
		if(idx < adj[cur].size()){
			pending.back().next = idx + 1;
			const int u = adj[cur][idx].first;
			if(u == par[cur])continue;
			par[u] = cur;
			H[u] = H[cur] + 1;
			sz[u] = 1;
			pending.push_back({u , 0});
			continue;
		}
		// cur is complete: fold its size into the vertex that opened it.
		pending.pop_back();
		if(pending.empty())break;
		const int up = pending.back().node;
		sz[up] += sz[cur];
		// While mx[up] is still 0 this compares against sz[0].
		if(sz[cur] > sz[mx[up]])mx[up] = cur;
	}
}

// Assigns head[] to every vertex strictly below v.
//
// The heavy child mx[x] inherits head[x]; every other child starts a chain of
// its own and becomes its own head. head[v] must already be set by the caller.
//
// v: vertex whose subtree is decomposed.
//
// Uses a worklist instead of recursion so that a very deep tree cannot
// overflow the call stack. A vertex's head is always written before the
// vertex is queued, so the order in which the worklist is drained does not
// influence the values stored.
void chaindfs(int v = 1){
	std::vector < int > pending(1 , v);
	while(!pending.empty()){
		const int cur = pending.back();
		pending.pop_back();
		const int heavy = mx[cur];
		if(heavy){
			head[heavy] = head[cur];
			pending.push_back(heavy);
		}
		for(const auto &edge : adj[cur]){
			const int u = edge.first;
			if(u == par[cur] || u == heavy)continue;
			head[u] = u;
			pending.push_back(u);
		}
	}
}

// Returns the lowest common ancestor of u and v using the chain heads.
//
// While the two vertices sit on different chains, the one whose chain head
// has the larger h[] is moved to the parent of that head. Once both share a
// chain, the vertex with the smaller h[] is the answer (v on a tie).
int lca(int u , int v){
	for(;;){
		if(head[u] == head[v])break;
		// Keep in u the vertex whose chain head is not closer to the root.
		if(h[head[v]] > h[head[u]])std::swap(u , v);
		u = par[head[u]];
	}
	return h[u] < h[v] ? u : v;
}

// Builds the tree and all preprocessing tables.
//
// N: number of vertices; the arrays describe N - 1 edges.
// A, B: endpoints of edge i (stored internally as id + 1).
// D: weight of edge i.
//
// After the adjacency lists are filled, dfs computes st/ft/h, szdfs computes
// sizes, parents and heavy children, and chaindfs assigns chain heads starting
// from vertex 1, which is its own head.
void Init(int N, int A[], int B[], int D[]) {
	for(int e = 0 ; e + 1 < N ; ++e){
		const int a = A[e] + 1;
		const int c = B[e] + 1;
		adj[a].emplace_back(c , D[e]);
		adj[c].emplace_back(a , D[e]);
	}
	dfs();
	szdfs();
	head[1] = 1;
	chaindfs();
}

// Stack of vertices (1-based, top at stk[ptr]) that Query uses while linking
// each vertex of the query tree to its nearest selected ancestor.
int stk[maxn] , ptr = 0;

// Child lists of the per-query tree; Query fills them and clears them again
// before returning.
vector < int > Adj[maxn];
// b[v] is set for the vertices listed in X, r[v] for those listed in Y, for
// the duration of one Query call.
bool b[maxn] , r[maxn];


void solve(int v){
	red[v] = blue[v] = inf;
	if(b[v])blue[v] = 0;
	if(r[v])red[v] = 0;
	for(auto u : Adj[v]){
		solve(u);
		blue[v] = min(blue[v] , blue[u] + h[u] - h[v]);
		red[v] = min(red[v] , red[u] + h[u] - h[v]);
	}
}


// Returns the smallest value of red[x] + blue[x] over the vertices x of the
// tree spanned by the two vertex sets, i.e. the shortest h-distance between a
// vertex of X and a vertex of Y routed through a common ancestor x.
//
// S, X: size and contents of the first set (ids as given to Init, 0-based).
// T, Y: size and contents of the second set.
//
// Returns inf when no pair exists; in particular when both sets are empty,
// which is handled up front because the code below needs at least one vertex.
//
// Steps: mark the vertices, sort them by dfs entry time, add the lca of every
// neighbouring pair, link each vertex to its nearest selected ancestor with a
// stack, evaluate the tree with solve, then undo all per-query state.
ll Query(int S, int X[], int T, int Y[]) {
	const int total = S + T;
	if(total <= 0)return inf;

	const auto byEntry = [](int lhs , int rhs){return st[lhs] < st[rhs];};

	std::vector < int > vec;
	// The marked vertices plus one lca per neighbouring pair.
	vec.reserve(2 * static_cast < std::size_t > (total));
	for(int i = 0 ; i < S ; i ++){
		vec.push_back(X[i] + 1);
		b[X[i] + 1] = 1;
	}
	for(int i = 0 ; i < T ; i ++){
		vec.push_back(Y[i] + 1);
		r[Y[i] + 1] = 1;
	}
	std::sort(vec.begin() , vec.end() , byEntry);
	for(int i = 1 ; i < total ; i ++)vec.push_back(lca(vec[i] , vec[i - 1]));
	std::sort(vec.begin() , vec.end() , byEntry);
	vec.erase(std::unique(vec.begin() , vec.end()) , vec.end());

	// vec is now in dfs entry order, so vec[0] is the root of the query tree.
	ptr = 0;
	stk[++ptr] = vec[0];
	for(std::size_t i = 1 ; i < vec.size() ; i ++){
		const int v = vec[i];
		// Drop stack entries whose subtree does not contain v.
		while(!(st[stk[ptr]] <= st[v] && ft[v] <= ft[stk[ptr]]))ptr--;
		Adj[stk[ptr]].push_back(v);
		stk[++ptr] = v;
	}
	solve(vec[0]);

	ll ans = inf;
	for(const int x : vec){
		Adj[x].clear();
		ans = std::min(ans , red[x] + blue[x]);
		b[x] = r[x] = 0;
	}
	return ans;
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...