Submission #1224130

#TimeUsernameProblemLanguageResultExecution timeMemory
1224130colossal_pepeFactories (JOI14_factories)C++20
33 / 100
8084 ms511652 KiB
#include "factories.h"
#include <bits/stdc++.h>
using namespace std;

using ll = long long;

// Sentinel "no distance yet" value; presumably larger than any real path
// length for the problem's constraints (TODO confirm against limits).
constexpr ll INF = 1'000'000'000'000'000'000LL;

// Number of vertices in the tree; set by Init.
int n = 0;
// Adjacency list of the input tree: t[u] holds (neighbor, edge length) pairs.
vector<vector<pair<int, ll>>> t;
// Parent of each vertex in the centroid tree; -1 marks the top-level centroid.
vector<int> ct_par;
// dist[c][v]: tree distance from centroid c to each vertex v of the component
// that c was chosen from (dist[c][c] == 0).
vector<unordered_map<int, ll>> dist;

int decompose(int s, vector<int> &subt_n, vector<bool> &done) {
	auto dfs1 = [s, &done, &subt_n](const auto &self, int u, int par) -> void {
		subt_n[u] = 1;
		for (auto [v, d] : t[u]) {
			if (done[v] or v == par) continue;
			self(self, v, u);
			subt_n[u] += subt_n[v];
		}
		// if (s == 4) cerr << "HELLO " << u << ' ' << subt_n[u] << endl;
	};
	dfs1(dfs1, s, -1);
	int cn = subt_n[s];

	auto dfs2 = [cn, &done, &subt_n](const auto &self, int u, int par) -> int {
		for (auto [v, d] : t[u]) {
			if (done[v] or v == par) continue;
			if (subt_n[v] * 2 > cn) return self(self, v, u);
		}
		return u;
	};
	int c = dfs2(dfs2, s, -1);
	// cerr << "bruh " << c << ' ' << cn << endl;

	done[c] = 1;
	dist[c][c] = 0;
	auto dfs3 = [c, &done](const auto &self, int u, int par) -> void {
		for (auto [v, d] : t[u]) {
			if (done[v] or v == par) continue;
			dist[c][v] = dist[c][u] + d;
			self(self, v, u);
		}
	};
	dfs3(dfs3, c, -1);

	for (auto [v, d] : t[c]) {
		if (done[v]) continue;
		int c_nxt = decompose(v, subt_n, done);
		ct_par[c_nxt] = c;
	}
	return c;
}

// Builds the global state used by Query from a tree on vertices 0..N-1.
//
// Edge i connects A[i] and B[i] with length D[i], for i in [0, N-1). The
// adjacency list `t` is filled from these edges. The centroid decomposition is
// then built, which populates `ct_par` and `dist`.
void Init(int N, int A[], int B[], int D[]) {
	n = N;
	// assign (rather than resize) so a repeated Init never sees stale data.
	t.assign(n, {});
	ct_par.assign(n, -1);
	dist.assign(n, {});

	for (int i = 0; i < n - 1; i++) {
		t[A[i]].emplace_back(B[i], D[i]);
		t[B[i]].emplace_back(A[i], D[i]);
	}

	// An empty tree has nothing to decompose; decompose(0, ...) would index
	// past the end of the scratch buffers.
	if (n == 0) return;

	vector<int> subt_n(n, 0);
	vector<bool> done(n, false);
	decompose(0, subt_n, done);
}

// Returns the minimum tree distance between any vertex in X[0..S) and any
// vertex in Y[0..T), or INF when either list is empty. Requires Init to
// have run.
//
// Pass 1: for every centroid ancestor u of each y, keep best[u] = min dist[u][y].
// Pass 2: for every centroid ancestor u of each x, combine dist[u][x] + best[u].
// Pass 3: reset the touched entries so the buffer is clean for the next call.
//
// The static buffer avoids allocating and hashing a fresh map per query. It
// also makes this function unsafe to call concurrently.
ll Query(int S, int X[], int T, int Y[]) {
	static vector<ll> best;
	if (best.size() != dist.size()) best.assign(dist.size(), INF);

	for (int i = 0; i < T; i++) {
		const int y = Y[i];
		for (int u = y; u != -1; u = ct_par[u]) {
			// at(): y is always in u's component; never silently insert 0.
			best[u] = min(best[u], dist[u].at(y));
		}
	}

	ll ret = INF;
	for (int i = 0; i < S; i++) {
		const int x = X[i];
		for (int u = x; u != -1; u = ct_par[u]) {
			if (best[u] != INF) ret = min(ret, dist[u].at(x) + best[u]);
		}
	}

	for (int i = 0; i < T; i++) {
		for (int u = Y[i]; u != -1; u = ct_par[u]) best[u] = INF;
	}

	return ret;
}
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...