제출 #635450

# | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리
635450 | S2speed | LOSTIKS (INOI20_lostiks) | C++17
100 / 100
1545 ms | 247468 KiB
#include<bits/stdc++.h>
 
using namespace std;
 
// Ask GCC for its most aggressive optimisation level for this translation unit.
#pragma GCC optimize ("Ofast")
 
// Size of a container as a signed int (not used in the visible code).
#define sze(x) (int)(x.size())
typedef long long ll;
typedef pair<ll , ll> pll;
typedef pair<int , int> pii;
// ((endpoint, endpoint), lock index or -1) -- one input edge.
typedef pair<pii , int> piii;
 
// maxn: capacity of the per-vertex arrays and of the first dimension of dp.
// inf:  "no answer" sentinel for the DP minimisation; main prints -1 when the
//       final answer is still equal to inf.
const ll maxn = (1 << 20) + 17 , inf = 2e8;
 
// Vertex count, start vertex and target vertex (s and t are made 0-based in main).
int n , s , t;
// Flattened adjacency storage: the entries adj[st[v]] .. adj[ft[v] - 1] are the
// (neighbour, lock index or -1) pairs of vertex v.
pii adj[maxn << 1];
// Edges in input order, kept until adj has been filled.
vector<piii> ed;
// dg[v]: degree of v.  st[v] / ft[v]: begin / end of v's slice of adj.
// x[v]: write cursor into v's slice while adj is being filled.
int dg[maxn] , st[maxn] , ft[maxn] , x[maxn];
// dis[h][v]: BFS distance to v; main uses h = 0 for the search from s and
// h = 1 for the search from t. Pre-filled with a large value meaning "not visited".
int dis[2][maxn];
// BFS queue (array-backed) and the number of vertices pushed so far.
int bfs[maxn] , sz = 0;
// g[v]:     bitmask of lock indices on the path from the rBFS root to v.
// k[j]:     vertex read as the third value of the j-th locked edge (0-based);
//           the DP in main visits it before using edge j.
// d[j]:     endpoint of locked edge j that is the BFS parent (closer to the root).
// e[j]:     g[k[j]] | g[d[j]] -- locks lying on the root paths to k[j] and d[j].
// f[i][j]:  dist(d[i], k[j]) + dist(d[j], k[j]).
// dp[mask][j]: DP table of main, indexed by a bitmask of locks and the lock
//           handled last.
// ind:      scratch list of the bit positions set in the current mask.
int g[maxn] , k[22] , d[22] , e[22] , f[22][22] , dp[maxn][22] , ind[22];
// jad[v][j]: the 2^j-th ancestor of v in the tree rooted by rBFS, or -1 if none.
int jad[maxn][20];

void rBFS(int r , int h){
	sz = 0;
	dis[h][r] = 0;
	g[r] = 0;
	bfs[sz++] = r;
	int x = 0;
	while(x < sz){
		int v = bfs[x++];
		for(int e = st[v] ; e < ft[v] ; e++){
			pii p = adj[e];
			int i = p.first , t = p.second;
			if(dis[h][i] < dis[h][v] + 1) continue;
			jad[i][0] = v;
			for(int j = 1 ; j < 20 ; j++){
				if(jad[i][j - 1] == -1) continue;
				jad[i][j] = jad[jad[i][j - 1]][j - 1];
			}
			g[i] = g[v];
			if(t != -1){
				g[i] ^= (1 << t);
				d[t] = v;
			}
			dis[h][i] = dis[h][v] + 1;
			bfs[sz++] = i;
		}
	}
	return;
}

// Returns the ancestor of v whose depth (dis[0]) equals d, by climbing
// dis[0][v] - d steps with the binary-lifting table jad.
// Only the low 20 bits of the step count are consulted.
int find_jad(int v , int d){
	const int climb = dis[0][v] - d;
	int j = 0;
	while(j < 20){
		if((climb >> j) & 1){
			v = jad[v][j];
		}
		++j;
	}
	return v;
}

// Lowest common ancestor of v and u in the tree rooted by rBFS, using the
// depths in dis[0] and the ancestor table jad.
int lca(int v , int u){
	// Keep u as the deeper (or equally deep) vertex, then lift it to v's depth.
	if(dis[0][u] < dis[0][v]) swap(v , u);
	u = find_jad(u , dis[0][v]);
	if(u == v) return u;

	// Descend the jump sizes; take a jump whenever the two ancestors differ,
	// so both vertices end up directly below the common ancestor.
	int j = 20;
	while(j-- > 0){
		const int up_v = jad[v][j];
		const int up_u = jad[u][j];
		if(up_v == up_u) continue;
		v = up_v;
		u = up_u;
	}
	return jad[v][0];
}

// Number of edges on the tree path between v and u:
// depth(v) + depth(u) - 2 * depth(lca(v, u)), with depths taken from dis[0].
int dist(int v , int u){
	const int top_depth = dis[0][lca(v , u)];
	return dis[0][v] + dis[0][u] - 2 * top_depth;
}
 
void BFS(int r , int h){
	sz = 0;
	dis[h][r] = 0;
	bfs[sz++] = r;
	int x = 0;
	while(x < sz){
		int v = bfs[x++];
		for(int e = st[v] ; e < ft[v] ; e++){
			pii p = adj[e];
			int i = p.first;
			if(dis[h][i] < dis[h][v] + 1) continue;
			dis[h][i] = dis[h][v] + 1;
			bfs[sz++] = i;
		}
	}
	return;
}
 
// Program entry point.
//
// Input (stdin): n, s, t, followed by n - 1 lines "v u w" (1-based vertices).
// w == 0 marks a free edge; otherwise the edge is locked and w names the
// vertex the walk must visit before that edge may be used (this is how the DP
// below treats it). Locked edges are numbered 0 .. m-1 in input order.
//
// Output (stdout): the length of the shortest walk from s to t, or -1 when
// the DP finds no feasible order of unlocking.
//
// Method: bitmask DP over the set of unlocked edges. dp[mask][j] is the
// cheapest walk that starts at s, unlocks exactly the locks in mask, unlocks
// j last, and ends at d[j] (the endpoint of edge j nearer to s).
int main(){
	ios_base::sync_with_stdio(false); cin.tie(0); cout.tie(0);

	memset(jad , -1 , sizeof(jad));  // -1 == "no such ancestor"
	memset(dp , 63 , sizeof(dp));    // every byte 0x3f: large "unreachable" value
	memset(dis , 63 , sizeof(dis));  // every byte 0x3f: "not visited yet"

	cin>>n>>s>>t; s--; t--;
	int m = 0;  // number of locked edges seen so far
	for(int i = 1 ; i < n ; i++){
		int v , u , w;
		cin>>v>>u>>w; v--; u--; w--;
		dg[v]++; dg[u]++;
		if(w != -1){
			k[m] = w;
			ed.push_back({{v , u} , m});
			m++;
		} else {
			ed.push_back({{v , u} , -1});
		}
	}

	// Lay the adjacency lists out contiguously: vertex v owns adj[st[v] .. ft[v]).
	x[0] = st[0] = 0; ft[0] = dg[0];
	for(int i = 1 ; i < n ; i++){
		x[i] = st[i] = ft[i - 1]; ft[i] = st[i] + dg[i];
	}
	for(const auto &p : ed){
		int v = p.first.first , u = p.first.second , w = p.second;
		adj[x[v]++] = {u , w}; adj[x[u]++] = {v , w};
	}

	// Root the tree at s: depths, ancestor table, lock masks g[] and d[].
	rBFS(s , 0);
	if(g[t] == 0){
		// No locked edge between s and t: the tree path is the answer.
		cout<<dis[0][t]<<'\n';
		return 0;
	}
	BFS(t , 1);  // dis[1][v] = distance from v to t

	// e[j]: locks that lie on the root paths to k[j] and to d[j].
	// Base case: lock j can be the very first one only if none is in the way.
	for(int j = 0 ; j < m ; j++){
		e[j] = g[k[j]] | g[d[j]];
		if(e[j] == 0){
			dp[(1 << j)][j] = dis[0][k[j]] + dist(d[j] , k[j]);
		}
	}
	// f[i][j]: cost of going from d[i] to k[j] and then on to d[j].
	for(int i = 0 ; i < m ; i++){
		for(int j = 0 ; j < m ; j++){
			f[i][j] = dist(d[i] , k[j]) + dist(d[j] , k[j]);
		}
	}

	int lm = (1 << m);
	for(int mask = 3 ; mask < lm ; mask++){
		int bp = 0;
		for(int j = 0 ; j < m ; j++){
			if(mask & (1 << j)) ind[bp++] = j;
		}
		if(bp == 1) continue;  // single-lock masks are the base case above
		for(int jj = 0 ; jj < bp ; jj++){
			int j = ind[jj];
			int msk = mask ^ (1 << j);
			// Everything needed to reach k[j] and d[j] must have been
			// unlocked BEFORE j. Testing against msk (not mask) matters when
			// k[j] sits behind lock j itself: such a lock can never be opened.
			if((e[j] & msk) != e[j]) continue;
			int h = inf;
			for(int ii = 0 ; ii < bp ; ii++){
				int i = ind[ii];
				if(i == j) continue;
				h = min(h , dp[msk][i] + f[i][j]);
			}
			dp[mask][j] = h;
		}
	}

	// Finish: any mask containing every lock on the s -> t path works; walk
	// from the endpoint of the last unlocked edge to t.
	int ans = inf;
	for(int mask = 1 ; mask < lm ; mask++){
		if((g[t] & mask) != g[t]) continue;
		int h = inf;
		for(int i = 0 ; i < m ; i++){
			if(!(mask & (1 << i))) continue;
			h = min(h , dp[mask][i] + dis[1][d[i]]);
		}
		ans = min(ans , h);
	}
	cout<<(ans == inf ? -1 : ans)<<'\n';
	return 0;
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...