Submission #497629

# | Time | Username | Problem | Language | Result | Execution time | Memory
497629 |  | zhougz | Roadside Advertisements (NOI17_roadsideadverts) | C++17 | 100 / 100 | 87 ms | 10924 KiB
/**
 *    author: zhougz
 *    created: 23/12/2021 19:37:06
**/
#include <bits/stdc++.h>

using namespace std;

// n: number of tree nodes; k: highest binary-lifting level actually used
// (main() sets k = log2(n)).
int n, k;

// floor(log2(x)) for x >= 1, usable in constant expressions.
// std::log2 is not constexpr, so `log2(MAXN) + 1` cannot portably size a
// file-scope array (it would be a VLA, which ISO C++ forbids at namespace scope).
constexpr int floor_log2(int x) {
	int r = 0;
	while (x > 1) {
		x >>= 1;
		++r;
	}
	return r;
}

constexpr int MAXN = 50'050;
// Number of binary-lifting levels (0 .. MAXK-1). For n <= MAXN, main()'s
// k = log2(n) is at most MAXK - 1, so anc[x][k] stays in bounds.
constexpr int MAXK = floor_log2(MAXN) + 1;

// Adjacency list: v[x] holds (neighbour, edge weight) pairs.
std::vector<std::pair<int, int>> v[MAXN];
// anc[x][i]: 2^i-th ancestor of x (-1 above the root for i == 0);
// dist[x]: weighted distance from the dfs root; dep[x]: depth in edges;
// par: per-query DSU over at most 20 compressed node indices.
int anc[MAXN][MAXK], dist[MAXN], dep[MAXN], par[20];
// Per-query candidate edges as (weight, index_a, index_b).
std::vector<std::tuple<int, int, int>> edges;

// Prepare per-query state: drop the previous query's candidate edges and
// reset all 20 DSU slots so every index is its own root.
void clear() {
	edges.clear();
	for (int slot = 0; slot < 20; ++slot) {
		par[slot] = slot;
	}
}

// DSU lookup: returns the root of x's set and points every node on the
// walked path directly at that root (full path compression).
int find(int x) {
	int root = x;
	while (par[root] != root) {
		root = par[root];
	}
	// Second pass: re-link each node on the path straight to the root.
	while (par[x] != root) {
		const int next = par[x];
		par[x] = root;
		x = next;
	}
	return root;
}

// DSU merge: hangs y's root beneath x's root (no rank/size heuristic).
void unite(int x, int y) {
	const int root_x = find(x);
	const int root_y = find(y);
	par[root_y] = root_x;
}

// Depth-first walk from x, whose tree parent is `parent` (-1 for the root).
// Fills anc[x][0..k] (stopping once the chain runs above the root) and sets
// dist/dep for every child. dist/dep of the start node are taken as already
// set; for the root they come from the zero-initialised globals.
// Recursion depth equals the tree height (up to n).
void dfs(int x, int parent) {
	anc[x][0] = parent;
	// anc[x][lvl] = 2^(lvl-1)-th ancestor of the 2^(lvl-1)-th ancestor.
	for (int lvl = 1; lvl <= k && anc[x][lvl - 1] != -1; ++lvl) {
		anc[x][lvl] = anc[anc[x][lvl - 1]][lvl - 1];
	}
	for (const auto &edge : v[x]) {
		const int child = edge.first;
		if (child == parent) {
			continue;
		}
		dist[child] = dist[x] + edge.second;
		dep[child] = dep[x] + 1;
		dfs(child, x);
	}
}

// Lowest common ancestor of a and b using the binary-lifting table built by
// dfs(); requires dfs() to have run on the tree containing both nodes.
int lca(int a, int b) {
	// Ensure b is the deeper (or equally deep) node.
	if (dep[b] < dep[a]) {
		swap(a, b);
	}
	// Raise b by the depth difference, one set bit at a time.
	const int gap = dep[b] - dep[a];
	for (int i = 0, bits = gap; i <= k; ++i, bits >>= 1) {
		if (bits & 1) {
			b = anc[b][i];
		}
	}
	assert(dep[a] == dep[b]);
	if (a != b) {
		// Climb both nodes while their ancestors differ; they end as siblings.
		for (int i = k; i >= 0; --i) {
			if (anc[a][i] != anc[b][i]) {
				a = anc[a][i];
				b = anc[b][i];
			}
		}
		assert(anc[a][0] == anc[b][0]);
		a = anc[a][0];
	}
	return a;
}

int main()
{
	ios::sync_with_stdio(false);
	cin.tie(nullptr);

	// Tree input: n nodes, then n - 1 weighted undirected edges. Node ids are
	// used directly as array indices (presumably 0-based); dfs roots at node 0.
	cin >> n;
	k = log2(n);
	int a, b, w;
	for (int e = 1; e < n; ++e) {
		cin >> a >> b >> w;
		v[a].emplace_back(b, w);
		v[b].emplace_back(a, w);
	}
	dfs(0, -1);

	int q;
	cin >> q;
	for (; q; --q) {
		// Five query nodes plus all their pairwise LCAs.
		int picks[5];
		for (int &p : picks) {
			cin >> p;
		}
		vector<int> nodes(picks, picks + 5);
		for (int i = 0; i < 5; ++i) {
			for (int j = i + 1; j < 5; ++j) {
				nodes.push_back(lca(picks[i], picks[j]));
			}
		}
		sort(nodes.begin(), nodes.end());
		nodes.erase(unique(nodes.begin(), nodes.end()), nodes.end());

		// Complete graph on the distinct nodes, weighted by tree distance.
		clear();
		const int m = static_cast<int>(nodes.size());
		for (int i = 0; i < m; ++i) {
			for (int j = i + 1; j < m; ++j) {
				const int u = nodes[i];
				const int t = nodes[j];
				edges.emplace_back(dist[u] + dist[t] - 2 * dist[lca(u, t)], i, j);
			}
		}

		// Kruskal: the MST weight of this node set is the printed answer.
		sort(edges.begin(), edges.end());
		int total = 0;
		for (const auto &[len, from, to] : edges) {
			if (find(from) == find(to)) {
				continue;
			}
			unite(from, to);
			total += len;
		}
		cout << total << '\n';
	}
	return 0;
}
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...