Submission #369521

# | Time | Username | Problem | Language | Result | Execution time | Memory
369521 | | soroush | Factories (JOI14_factories) | C++17 | 0 / 100 | 21 ms | 24172 KiB
#include "factories.h"
#include <bits/stdc++.h>

using namespace std;


using ll = long long;
// Adjacency entry of the input tree: {neighbour vertex, edge length}.
using pii = pair < ll , ll >;
// Capacity for per-vertex arrays; vertices are stored 1-based (index 0 is a "no vertex" sentinel).
constexpr ll maxn = 500'100;

#define pb push_back

// "Unreachable" distance sentinel used by the red/blue DP.
constexpr ll inf = 1'000'000'000'000'000'000LL;

// Euler-tour entry (st) and exit (ft) times, plus the running timer dfs() advances.
ll st[maxn] , ft[maxn] , ti = 0;
// Weighted adjacency list of the input tree.
vector < pii > adj[maxn];
// h[v]: weighted distance from the DFS root to v (filled in by dfs()).
ll h[maxn];

// blue[v] / red[v]: distance from v down to the nearest X-marked / Y-marked vertex
// inside v's subtree of the current query's virtual tree (inf if none).
ll red[maxn] , blue[maxn];

// Euler tour from v (whose parent is p): assigns entry time st[], exit time ft[]
// (the largest entry time inside the subtree) and weighted root distance h[] for
// every descendant. Children are visited in adjacency order, exactly as a
// recursive preorder would; an explicit frame stack replaces the call stack so a
// long path cannot exhaust the native stack.
void dfs(ll v = 1 , ll p = 0){
	struct Frame { ll vertex , parent; size_t cursor; };
	vector < Frame > frames;
	st[v] = ++ti;
	frames.push_back({v , p , 0});
	while(!frames.empty()){
		Frame &top = frames.back();
		const ll x = top.vertex;
		if(top.cursor == adj[x].size()){
			// Every child of x is finished: close its Euler interval.
			ft[x] = ti;
			frames.pop_back();
			continue;
		}
		const auto [u , w] = adj[x][top.cursor++];
		if(u == top.parent) continue;
		h[u] = h[x] + w;
		st[u] = ++ti;
		// `top` may dangle after this push; it is re-fetched on the next iteration.
		frames.push_back({u , x , 0});
	}
}

// Heavy-light decomposition tables (vertex 0 means "none"):
//   sz[v]   number of vertices in v's subtree
//   H[v]    depth of v counted in edges (the root stays 0)
//   head[v] top vertex of the heavy chain containing v (filled by chaindfs)
//   par[v]  parent of v
//   mx[v]   child of v with the largest subtree, 0 when v has no children
ll sz[maxn] , H[maxn] , head[maxn] , par[maxn] , mx[maxn];

// Post-order pass from v: sets par/H for descendants, subtree sizes sz, and the
// heavy child mx. Ties keep the earliest child in adjacency order (strict '>').
// Iterative: each frame holds a vertex and the next adjacency index to inspect.
void szdfs(ll v = 1){
	vector < pair < ll , size_t > > frames;
	sz[v] = 1;
	frames.push_back({v , 0});
	while(!frames.empty()){
		const ll x = frames.back().first;
		size_t &cursor = frames.back().second;
		if(cursor < adj[x].size()){
			const ll u = adj[x][cursor++].first;
			if(u == par[x]) continue;
			par[u] = x;
			H[u] = H[x] + 1;
			sz[u] = 1;
			frames.push_back({u , 0}); // invalidates `cursor`; it is not used again this round
			continue;
		}
		// x's subtree is complete; fold it into the frame below (its parent).
		frames.pop_back();
		if(frames.empty()) break;
		const ll parent = frames.back().first;
		sz[parent] += sz[x];
		if(sz[x] > sz[mx[parent]]) mx[parent] = x;
	}
}

// Assigns heavy-chain heads below v (head[v] must already be set by the caller).
// The heavy child mx[x] continues its parent's chain; every other child starts a
// new chain headed by itself. Uses an explicit work list instead of recursion.
void chaindfs(ll v = 1){
	vector < ll > todo{v};
	while(!todo.empty()){
		const ll x = todo.back();
		todo.pop_back();
		for(const auto &edge : adj[x]){
			const ll u = edge.first;
			if(u == par[x]) continue;
			head[u] = (u == mx[x]) ? head[x] : u;
			todo.push_back(u);
		}
	}
}

// Lowest common ancestor of u and v using the heavy-light decomposition.
// While the two vertices sit on different heavy chains, the one whose chain head
// is deeper jumps to the parent of that head; once they share a chain, the
// shallower one is the LCA.
// Depth is measured with H (edge count from the root, filled in by szdfs) rather
// than the weighted distance h: with a zero-length edge h can tie between distinct
// ancestors, the wrong endpoint may jump past the root to par[1] == 0, and since
// head[0] == par[0] == 0 the loop would never terminate.
ll lca(ll u , ll v){
	while(head[u] != head[v]){
		if(H[head[u]] < H[head[v]])swap(u , v);
		u = par[head[u]];
	}
	return H[u] < H[v] ? u : v;
}

void Init(int N, int A[], int B[], int D[]) {
	for(ll i = 0 ; i < N - 1 ; i ++)
		adj[A[i]+1].pb({B[i]+1 , D[i]}),
		adj[B[i]+1].pb({A[i]+1 , D[i]});
	dfs();
	szdfs();
	head[1] = 1;
	chaindfs();
}

// Stack used by Query() while linking virtual-tree vertices; stk[1..ptr] is live.
ll stk[maxn];
ll ptr = 0;

// Child lists of the current query's virtual tree; Query() clears them before returning.
vector < ll > Adj[maxn];
// Per-query marks: b[v] when v is one of the X vertices, r[v] when v is one of the Y vertices.
// Query() resets both before returning.
bool b[maxn];
bool r[maxn];


// Bottom-up DP over the virtual tree rooted at v (children in Adj).
// blue[x] / red[x] end up as the distance from x down to the nearest b-marked /
// r-marked vertex in x's virtual subtree, or inf when there is none; edge lengths
// are differences of the weighted depth h.
// Implemented iteratively: collect vertices in BFS order (parents before children),
// then relax each vertex from its children while walking that order backwards.
void solve(ll v){
	vector < ll > order{v};
	for(size_t i = 0 ; i < order.size() ; i++)
		for(ll child : Adj[order[i]])
			order.push_back(child);

	for(ll x : order){
		red[x] = blue[x] = inf;
		if(b[x]) blue[x] = 0;
		if(r[x]) red[x] = 0;
	}

	for(size_t i = order.size() ; i-- > 0 ; ){
		const ll x = order[i];
		for(ll child : Adj[x]){
			const ll edge = h[child] - h[x];
			blue[x] = min(blue[x] , blue[child] + edge);
			red[x] = min(red[x] , red[child] + edge);
		}
	}
}


// Grader entry point: returns the minimum tree distance between any of the S
// vertices in X and any of the T vertices in Y (0-based ids), or inf if no pair exists.
// Steps:
//   1. mark X (b) and Y (r), build the LCA-closed virtual tree of X ∪ Y;
//   2. run the red/blue DP (solve) on it;
//   3. combine red + blue at every virtual vertex and reset all per-query state.
ll Query(int S, int X[], int T, int Y[]) {
	ll ans = inf;
	// Nothing to match; also keeps vec[0] below in range.
	if(S + T <= 0) return ans;

	// Preorder comparator (parameters deliberately not named `b`, which is the global X-mark array).
	auto byEntry = [](ll x , ll y){ return st[x] < st[y]; };

	vector < ll > vec;
	vec.reserve(2 * (S + T));
	for(ll i = 0 ; i < S ; i ++)vec.pb(X[i]+1) , b[X[i] + 1] = 1;
	for(ll i = 0 ; i < T ; i ++)vec.pb(Y[i]+1) , r[Y[i] + 1] = 1;

	// Adding the LCA of every pair of preorder-adjacent vertices closes the set under LCA.
	sort(vec.begin() , vec.end() , byEntry);
	for(ll i = 1 ; i < S + T ; i ++)vec.pb(lca(vec[i] , vec[i - 1]));
	sort(vec.begin() , vec.end() , byEntry);
	vec.resize(unique(vec.begin() , vec.end()) - vec.begin());

	// vec[0] (smallest entry time) is the LCA of the whole set and roots the virtual tree.
	ptr = 0;
	stk[++ptr] = vec[0];
	for(ll i = 1 ; i < (ll)vec.size() ; i++){
		ll v = vec[i];
		// Pop until the stack top's Euler interval contains v, i.e. it is v's virtual parent.
		while(st[stk[ptr]] > st[v] or ft[v] > ft[stk[ptr]])ptr--;
		Adj[stk[ptr]].pb(v);
		stk[++ptr] = v;
	}
	solve(vec[0]);
	// Best X–Y path through each virtual vertex; then clear per-query structures.
	for(ll i : vec)Adj[i].clear() , ans = min(ans , red[i] + blue[i]) , b[i] = r[i] = 0;
	return ans;
}
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...