Submission #1289521

# | Time | Username | Problem | Language | Result | Execution time | Memory
1289521 | | shidou26 | LOSTIKS (INOI20_lostiks) | C++20 | 0 / 100 | 5 ms | 568 KiB
#include <bits/stdc++.h>
using namespace std;

#ifdef KURUMI
    #include "algo/debug.h"
#endif

// Shorthand macros used throughout the file.
// `endl` is replaced by a plain newline character, so `<< endl` never flushes.
#define endl '\n'
// Short names for the members of std::pair.
#define fi first
#define se second
// Container size as int. The argument is parenthesized so that expressions
// such as sz(a + b) expand correctly.
#define sz(v) ((int)(v).size())
// Expands to the begin/end iterator pair of a container.
#define all(v) (v).begin(), (v).end()
// Drops adjacent duplicates in place (sort first to deduplicate globally).
#define filter(v) (v).resize(unique(all(v)) - (v).begin())
// Debug helper: streams "[name = value]" for an expression.
#define dbg(x) "[" #x << " = " << x << "]" 

// Global 64-bit Mersenne Twister, seeded from the steady clock at start-up.
std::mt19937_64 rng(std::chrono::steady_clock::now().time_since_epoch().count());

// Draws a uniformly distributed integer of type T2 from the closed range [l, r].
template<typename T1, typename T2> T2 rand(T1 l, T2 r) {
    std::uniform_int_distribution<T2> pick(l, r);
    return pick(rng);
}
// Biased random draw from [l, r]: takes |seed| + 1 samples via rand(l, r) and
// returns their maximum when seed > 0, their minimum when seed < 0, and a
// single plain sample when seed == 0.
template<typename T1, typename T2> T2 wrand(T1 l, T2 r, int seed) {
    T2 best = rand(l, r);
    const int step = seed > 0 ? -1 : 1;
    for(int left = seed; left != 0; left += step) {
        const T2 sample = rand(l, r);
        best = seed > 0 ? std::max(best, sample) : std::min(best, sample);
    }
    return best;
}

// Raises `a` to `b` when `b` is strictly larger.
// Returns true exactly when `a` was changed.
template<typename T> bool maximize(T &a, T b) {
    const bool improved = a < b;
    if(improved) a = b;
    return improved;
}
// Lowers `a` to `b` when `b` is strictly smaller.
// Returns true exactly when `a` was changed.
template<typename T> bool minimize(T &a, T b) {
    const bool improved = a > b;
    if(improved) a = b;
    return improved;
}

// Short aliases for common numeric, pair and tuple types.
using ll = long long;
using ld = long double;
using pii = std::pair<int, int>;
using pli = std::pair<ll, int>;
using pll = std::pair<ll, ll>;
using tp3 = std::tuple<int, int, int>;

// Capacity of the per-vertex arrays.
constexpr int N = 1000003;
// NOTE(review): M is not referenced anywhere in the visible part of the file.
constexpr int M = 20;

// Header values read by input().
int n, m, s, t;
// Endpoints of locked edges, collected by input().
std::vector<int> useful;
// Locked edges stored as (u, v, lock).
std::vector<tp3> edge;
// Adjacency lists indexed by vertex.
std::vector<int> adj[N];

// Union-by-size disjoint-set forest with checkpoint/rollback support.
// No path compression is performed, so every change made by unite() can be
// undone exactly.
struct DisjointSet {
	// lab[x] < 0  : x is a root and -lab[x] is the size of its component.
	// lab[x] >= 0 : lab[x] is the parent of x.
	// save holds the undo-log lengths recorded by persist().
	std::vector<int> lab, save;
	// Undo log of (index into lab, value before the change). Indices are
	// stored instead of references so that a copied or moved DisjointSet
	// never writes into another object's storage.
	std::vector<std::pair<int, int>> history;

	DisjointSet () {}
	// Creates n + 3 singleton components (indices 0 .. n + 2).
	DisjointSet (int n) : lab(n + 3, -1) {}

	// Returns the representative of u's component.
	int root(int u) {
		while(lab[u] >= 0) u = lab[u];
		return u;
	}

	// Merges the components of u and v, attaching the smaller one under the
	// larger. Returns false when they were already in the same component.
	bool unite(int u, int v) {
		u = root(u); v = root(v);
		if(u == v) return false;
		if(-lab[u] < -lab[v]) std::swap(u, v);
		history.emplace_back(u, lab[u]);
		history.emplace_back(v, lab[v]);
		lab[u] += lab[v];
		lab[v] = u;
		return true;
	}

	// True when u and v belong to the same component.
	bool same(int u, int v) {
		return root(u) == root(v);
	}

	// Records a checkpoint that a later rollback() returns to.
	void persist() {
		save.push_back(static_cast<int>(history.size()));
	}

	// Undoes every unite() performed since the most recent checkpoint and
	// discards that checkpoint. Does nothing when no checkpoint exists.
	void rollback() {
		if(save.empty()) return;
		const int target = save.back(); save.pop_back();
		while(static_cast<int>(history.size()) > target) {
			lab[history.back().first] = history.back().second;
			history.pop_back();
		}
	}
} dsu;

// Reads the header (n, m, s, t) followed by n - 1 edges "u v lock".
// Every edge is added to the adjacency lists. Edges with lock == 0 are merged
// into the global DSU immediately; the others are stored in `edge` and their
// endpoints are appended to `useful`.
// NOTE(review): m is taken from the header while locked edges are counted
// separately in `edge` - confirm the two always agree for the target judge.
void input() {
	cin >> n >> m >> s >> t;
	dsu = DisjointSet(n);

	for(int remaining = n - 1; remaining > 0; remaining--) {
		int a, b, key;
		cin >> a >> b >> key;
		adj[a].push_back(b);
		adj[b].push_back(a);

		if(key == 0) {
			// Unlocked edge: its endpoints are connected from the start.
			dsu.unite(a, b);
			continue;
		}

		edge.emplace_back(a, b, key);
		useful.push_back(a);
		useful.push_back(b);
	}
}

// Number of binary-lifting levels kept per vertex.
constexpr int LOG = 20;
// Sentinel for "unreachable"; small enough that adding path lengths to it
// stays within int.
constexpr int INF = 0x3f3f3f3f;

// Best total cost found so far by process().
int answer = INF;
// h[v]: depth of v as assigned by prepare().
int h[N];
// id[v]: position of v inside the deduplicated `useful` list.
int id[N];
// par[v][j]: the 2^j-th ancestor of v (0 when it does not exist).
int par[N][LOG];

// Fills h[] (depths) and par[][] (binary-lifting ancestors) for the subtree
// reached from u when p is treated as u's parent.
//
// u: vertex to start from; its own h[u] and par[u][0] are left untouched.
// p: neighbour of u that must not be entered (pass -1 for the root).
//
// The traversal uses an explicit stack instead of recursion: on a path-shaped
// tree the recursion depth would equal the number of vertices (N allows about
// 1e6), which can overflow the call stack.
void prepare(int u, int p) {
	vector<pair<int, int>> pending; // (vertex, vertex it was reached from)
	pending.emplace_back(u, p);

	while(!pending.empty()) {
		int cur, from;
		tie(cur, from) = pending.back();
		pending.pop_back();

		// Every ancestor of cur was popped earlier, so its table is complete.
		for(int j = 1; j < LOG; j++) {
			par[cur][j] = par[par[cur][j - 1]][j - 1];
		}

		for(int nxt : adj[cur]) {
			if(nxt == from) continue;
			h[nxt] = h[cur] + 1;
			par[nxt][0] = cur;
			pending.emplace_back(nxt, cur);
		}
	}
}

// Lowest common ancestor of u and v, computed by binary lifting over the
// h[] and par[][] tables.
int lca(int u, int v) {
	// Keep u as the deeper (or equally deep) endpoint.
	if(h[v] > h[u]) swap(u, v);

	// Lift u until both vertices sit at the same depth.
	for(int bit = LOG; bit-- > 0; ) {
		if(h[u] - h[v] >= (1 << bit)) u = par[u][bit];
	}
	if(u == v) return v;

	// Lift both while their ancestors differ; they end just below the LCA.
	for(int bit = LOG; bit-- > 0; ) {
		const int upU = par[u][bit];
		const int upV = par[v][bit];
		if(upU == upV) continue;
		u = upU;
		v = upV;
	}

	return par[u][0];
}

// Number of edges on the tree path between u and v, computed from the depths
// of both vertices and of their lowest common ancestor.
int distance(int u, int v) {
	const int meet = lca(u, v);
	return (h[u] - h[meet]) + (h[v] - h[meet]);
}

// Solves one test case and prints the answer (-1 when t cannot be reached).
//
// dp[mask][id[x]] is the cheapest walk found that starts at s, has opened
// exactly the locked edges in `mask`, and currently stands on endpoint x of
// one of those edges. Opening edge j means walking to its key vertex `lock`
// and then to one endpoint of j; both legs must stay inside the component
// formed by unlocked edges plus the edges already in `mask`.
//
// The number of locked edges is taken from `edge` itself rather than from the
// header value m, so a mismatch between the two can no longer index past the
// end of `edge`.
void process() {
	prepare(1, -1);
	if(dsu.same(s, t)) {
		// No locked edge separates s from t.
		cout << distance(s, t) << endl;
		return;
	}

	sort(all(useful)); filter(useful);
	const int k = sz(useful);
	for(int i = 0; i < k; i++) id[useful[i]] = i;

	const int locks = sz(edge);
	vector<vector<int>> dp(1 << locks, vector<int>(k + 1, INF));

	// Base states: open a single edge starting from s.
	for(int i = 0; i < locks; i++) {
		int u, v, lock; tie(u, v, lock) = edge[i];
		if(!dsu.same(s, lock)) continue;
		const int toKey = distance(s, lock);
		if(dsu.same(lock, u)) dp[1 << i][id[u]] = toKey + distance(lock, u);
		if(dsu.same(lock, v)) dp[1 << i][id[v]] = toKey + distance(lock, v);
	}

	for(int mask = 0; mask < (1 << locks); mask++) {
		// Temporarily join the endpoints of every opened edge.
		dsu.persist();
		for(int i = 0; i < locks; i++) if(mask >> i & 1) {
			int u, v; tie(u, v, ignore) = edge[i];
			dsu.unite(u, v);
		}

		// Extend every reachable state by opening one more edge.
		for(int i = 0; i < locks; i++) if(mask >> i & 1) {
			int ends[2]; tie(ends[0], ends[1], ignore) = edge[i];
			for(int from : ends) {
				const int base = dp[mask][id[from]];
				if(base >= INF) continue; // unreachable state cannot improve anything

				for(int j = 0; j < locks; j++) if(!(mask >> j & 1)) {
					int nu, nv, lock; tie(nu, nv, lock) = edge[j];
					if(!dsu.same(from, lock)) continue;
					const int atKey = base + distance(from, lock);
					if(dsu.same(lock, nu)) minimize(dp[mask ^ (1 << j)][id[nu]], atKey + distance(lock, nu));
					if(dsu.same(lock, nv)) minimize(dp[mask ^ (1 << j)][id[nv]], atKey + distance(lock, nv));
				}
			}
		}

		// Finish at t from any endpoint that shares its component.
		for(int i = 0; i < k; i++) {
			if(dsu.same(useful[i], t)) minimize(answer, dp[mask][i] + distance(useful[i], t));
		}
		dsu.rollback();
	}

	cout << (answer == INF ? -1 : answer) << endl;
}

// Entry point: sets up fast stream I/O, optionally redirects stdin/stdout to
// TREEMAZE.inp / TREEMAZE.out when the input file exists, then runs the single
// test case.
int main() {
    ios_base::sync_with_stdio(0);
    cin.tie(0);

    #define task "TREEMAZE"
    // Probe for the local input file; close the probe handle so it is not leaked.
    if(FILE *probe = fopen(task".inp", "r")) {
        fclose(probe);
        // A failed redirection would leave the streams unusable, so stop early.
        if(!freopen(task".inp", "r", stdin) || !freopen(task".out", "w", stdout)) {
            cerr << "Failed to redirect standard streams to " task " files" << endl;
            return 1;
        }
    }
    
    int testcase = 1; // cin >> testcase;    
    for(int i = 1; i <= testcase; i++) {
        input();
        process();
    }

    cerr << "Saa, watashtachi no deeto hajimemashou" << endl;
    cerr << "Atarashii kiseki wo koko kara hajimeru shining place nee mou ichido kimi to..." << endl;
    
    return 0;
}

Compilation message (stderr)

Main.cpp: In function 'int main()':
Main.cpp:217:16: warning: ignoring return value of 'FILE* freopen(const char*, const char*, FILE*)' declared with attribute 'warn_unused_result' [-Wunused-result]
  217 |         freopen(task".inp", "r", stdin);
      |         ~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~
Main.cpp:218:16: warning: ignoring return value of 'FILE* freopen(const char*, const char*, FILE*)' declared with attribute 'warn_unused_result' [-Wunused-result]
  218 |         freopen(task".out", "w", stdout);
      |         ~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~~
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...