// Submission #1250845
//
// # | Time | Username | Problem | Language | Result | Execution time | Memory
// 1250845 | | pasta | LOSTIKS (INOI20_lostiks) | C++20 | 100 / 100 | 1719 ms | 284400 KiB
#include <bits/stdc++.h>
using namespace std;

// Short aliases used throughout the file.
typedef long long ll;
typedef pair<int, int> pii;

// Shorthand macros. NOTE: `S` and `F` expand everywhere, so no identifier in
// this file may be named S, F, pb, lc or rc.
#define pb		push_back
// lc / rc expand to expressions over an identifier `id`; nothing in this file uses them.
#define lc		(id * 2)
#define rc		(lc + 1)
#define S		second
#define F       first
// Unties cin/cout and disables synchronisation with C stdio; invoked once at the top of main.
#define migmig cin.tie(0); cout.tie(0); ios::sync_with_stdio(false)

// Disabled switch: enabling it would turn every `int` into `ll`.
// #define int		ll
// GCC-specific optimisation level and instruction-set hints.
#pragma GCC optimize("Ofast,unroll-loops,inline")   
#pragma GCC target("avx2,bmi2,lzcnt")  

// Capacity of every per-vertex / per-edge array.
const int maxn = 1e6 + 10;
// Not referenced anywhere in this file.
const int mod = 1e9 + 7;
// Number of binary-lifting levels in par[][]; also the capacity of the
// per-lock dimensions of dp[][] and dis[][].
const int LOG = 21;
// Sentinel for "no route found"; main prints -1 when the answer is still >= inf.
const int inf = 1e9 + 10;

// n        - number of vertices (main reads n - 1 edges).
// m        - number of locked edges, i.e. e.size().
// s, t     - the two vertices read after n; dfs roots the tree at s, f[] measures distance to t.
// dp[i][mask] - filled by main: f[P[e[i]]] when no lock in `mask` lies between t and P[e[i]],
//               otherwise the minimum over usable locks bt in `mask` of
//               dis[i][bt] + dp[bt][mask without bt] (inf when there is none).
// par[v][i]  - ancestor of v reached by 2^i parent steps, as filled by dfs.
// qmask[v]   - OR of (1 << lock index) over the locked edges dfs crossed from its root to v.
// h[v]       - depth of v below the dfs root.
// f[v]       - dist(v, t), where main has filled it.
// dis[i][j]  - dist(P[e[i]], key[e[j]]) + dist(key[e[j]], P[e[j]]).
// key[i]     - third value read for edge i; main passes it to dist()/path() as a vertex when > 0.
int n, m, s, t, dp[LOG][(1 << LOG)], par[maxn][LOG], qmask[maxn], h[maxn], f[maxn], dis[LOG][LOG], key[maxn];
// Table for dist(v, s); nothing in this file reads it.
int fs[maxn];
// f[i] = dist(i, t)
// G[v]  - adjacency list of v: (neighbour, lock index of the edge, or -1 when the edge has none).
// edge  - endpoints of each edge in input order.
vector<pii> G[maxn], edge;
// Input indices of the edges whose third value is > 0; position in e is the lock index.
vector<int> e;

/**
 * Walks the tree starting at v (whose parent is p) and fills, for every
 * vertex x that is reached:
 *   par[x][0]  - the vertex x was entered from (p for the start vertex),
 *   par[x][i]  - par[par[x][i - 1]][i - 1], i.e. the ancestor 2^i steps up,
 *   h[x]       - h of its parent plus one (h[v] itself is left untouched),
 *   qmask[x]   - the parent's qmask, with bit w added when the connecting
 *                edge carries an index w >= 0.
 *
 * The traversal uses an explicit stack instead of recursion so that its depth
 * is not limited by the call stack: the per-vertex arrays are sized for about
 * a million vertices, and a path-shaped tree would otherwise need one stack
 * frame per vertex. A vertex is only pushed after its parent's par[] row is
 * complete, so every ancestor row read here is already final.
 */
void dfs(int v, int p) {
	std::vector<std::pair<int, int>> pending;  // (vertex, vertex it is entered from)
	pending.emplace_back(v, p);
	while (!pending.empty()) {
		const auto [cur, from] = pending.back();
		pending.pop_back();

		par[cur][0] = from;
		for (int i = 1; i < LOG; i++)
			par[cur][i] = par[par[cur][i - 1]][i - 1];

		for (const auto& [u, w] : G[cur]) {
			if (u == from) continue;  // do not walk back up the edge we came from
			qmask[u] = (w >= 0) ? (qmask[cur] | (1 << w)) : qmask[cur];
			h[u] = h[cur] + 1;
			pending.emplace_back(u, cur);
		}
	}
}

// Lifts v by d parent steps using the binary-lifting table: for every set bit
// `bit` of d (below LOG) it jumps through par[.][bit]. Returns the vertex reached.
int get_par(int v, int d) {
	int node = v;
	for (int bit = 0; bit < LOG; ++bit) {
		const bool take = ((d >> bit) & 1) != 0;
		if (take)
			node = par[node][bit];
	}
	return node;
}

// Lowest common ancestor of v and u in the tree described by par[][] and h[].
int LCA(int v, int u) {
	// Keep u as the deeper of the two, then lift it to v's depth.
	if (h[u] < h[v])
		std::swap(v, u);
	u = get_par(u, h[u] - h[v]);
	if (v == u)
		return v;

	// Jump both vertices upward, largest step first, while their ancestors differ.
	int level = LOG;
	while (level-- > 0) {
		const int upV = par[v][level];
		const int upU = par[u][level];
		if (upV == upU)
			continue;
		v = upV;
		u = upU;
	}
	// v and u are now distinct children of the answer.
	return par[v][0];
}


// Number of edges between v and u: each vertex's depth below their lowest
// common ancestor, added together.
int dist(int v, int u) {
	const int meet = LCA(v, u);
	const int climbV = h[v] - h[meet];
	const int climbU = h[u] - h[meet];
	return climbV + climbU;
}

// XOR of the two root-to-vertex lock masks: the bits that are set in exactly
// one of qmask[v] and qmask[u].
int path(int v, int u) {
	const int maskV = qmask[v];
	const int maskU = qmask[u];
	return maskV ^ maskU;
}

// P[i]: for input edge i, the endpoint that is the dfs parent of the other
// endpoint (filled in main after dfs has run).
int P[maxn];

signed main() {
	migmig;
	cin >> n >> s >> t;
	int ted = 0;
	for (int i = 0; i < n - 1; i++) {
		int v, u, w;
		cin >> v >> u >> w;
		key[i] = w;
		if (w > 0)
			ted++;
		int x = ted - 1;
		if (w == 0)
			x = -1;
		G[v].pb({u, x});
		G[u].pb({v, x});
		edge.pb({v, u});
		if (w > 0)
			e.pb(i);
	}
	dfs(s, 0);
	for (int i = 0; i < n - 1; i++) {
		int v = edge[i].F, u = edge[i].S;
		if (v == par[u][0])
			P[i] = v;
		else
			P[i] = u;
	}

	for (int i = 1; i <= n; i++)
		f[i] = dist(i, t), fs[i] = dist(i, s);

	for (int i = 0; i < e.size(); i++) {
		for (int j = 0; j < e.size(); j++) {
			dis[i][j] = dist(P[e[i]], key[e[j]]) + dist(key[e[j]], P[e[j]]);
		}
	}
	// swap(s, t);
	m = int(e.size());
	int ans = inf;
	for (int mask = 0; mask < (1 << m); mask++) {
		for (int i = 0; i < m; i++) {
			//dp[i][mask]
			dp[i][mask] = inf;
			if ((path(t, P[e[i]]) & mask) == 0) {
				dp[i][mask] = f[P[e[i]]];
				continue;
			}
			for (int bt = 0; bt < m; bt++) {
				if (((1 << bt) & mask) == 0) continue;
				if ((path(key[e[bt]], P[e[bt]]) & mask) == 0 && ((path(P[e[i]], key[e[bt]]) & mask) == 0 ))
					dp[i][mask] = min(dp[i][mask], dis[i][bt] + dp[bt][mask ^ (1 << bt)]);
			}
		}
	}

	for (int i = 0; i < m; i++) {
		int mask = (1 << m) - 1;
		if (path(s, key[e[i]]) == 0 && path(key[e[i]], P[e[i]]) == 0) {
			ans = min(ans, dist(s, key[e[i]]) + dist(key[e[i]], P[e[i]]) + dp[i][mask ^ (1 << i)]);
		}
	}

	if (path(s, t) == 0)
		ans = min(ans, dist(s, t));
	// for (int mask = 0; mask < (1 << m); mask++) {
	// 	for (int i = 0; i < m; i++) {
	// 		//dp[i][mask]
	// 		cout << i << ' ' << mask << ' ' << dp[i][mask] << '\n';
	// 	}
	// }
	if (ans >= inf)
		ans = -1;
	cout << ans << '\n';
}
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...