// Submission #570302
//
// # | Time | Username | Problem | Language | Result | Execution time | Memory
// 570302 | Virv | Mergers (JOI19_mergers) | C++17
// Result: 100 / 100
// Execution time: 1315 ms, Memory: 102360 KiB
#include <algorithm>
#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <set>
#include <utility>
#include <vector>

using u32 = std::uint32_t;

namespace {

// Roots the tree `g` at vertex 0, filling par[v] (parent; par[0] == 0) and
// d[v] (depth; d[0] == 0). Uses an explicit stack instead of recursion so a
// long path-shaped tree cannot overflow the call stack.
void root_tree(std::vector<std::vector<u32>> const &g, std::vector<u32> &par,
               std::vector<u32> &d) {
	std::vector<u32> stack{0};
	par[0] = 0;
	d[0]   = 0;
	while (!stack.empty()) {
		u32 const v = stack.back();
		stack.pop_back();
		for (u32 const u : g[v]) {
			if (u == par[v]) continue;  // don't walk back to the parent
			par[u] = v;
			d[u]   = d[v] + 1;
			stack.push_back(u);
		}
	}
}

// Union-find lookup: returns the representative of i's set and points every
// vertex on the walked path directly at it (iterative path compression).
u32 find_root(std::vector<u32> &uf, u32 i) {
	u32 root = i;
	while (uf[root] != root) root = uf[root];
	while (uf[i] != root) {
		u32 const next = uf[i];
		uf[i]          = root;
		i              = next;
	}
	return root;
}

// Unites the sets containing i and j. The representative with the smaller
// depth d[] is kept, so a set built by merging vertices along tree edges is
// represented by its shallowest vertex.
void unite(std::vector<u32> &uf, std::vector<u32> const &d, u32 i, u32 j) {
	i = find_root(uf, i);
	j = find_root(uf, j);
	if (i == j) return;
	if (d[i] > d[j]) std::swap(i, j);
	uf[j] = i;
}

// For every group in `states`, merges the components of all its vertices
// into one. It does this by repeatedly taking the deepest pending component
// root and merging it with its parent's component until one root remains.
// This also absorbs every vertex on the tree paths between the group's
// vertices.
void contract_states(std::vector<std::vector<u32>> const &states,
                     std::vector<u32> const &par, std::vector<u32> const &d,
                     std::vector<u32> &uf) {
	// (depth, root) pairs ordered largest-first, so begin() is the deepest
	// pending root (ties broken by larger vertex index).
	std::set<std::pair<u32, u32>, std::greater<>> pending;
	for (auto const &group : states) {
		pending.clear();
		for (u32 const x : group) {
			u32 const r = find_root(uf, x);
			pending.emplace(d[r], r);
		}

		while (pending.size() > 1) {
			u32 const v = pending.begin()->second;
			pending.erase(pending.begin());

			unite(uf, d, v, par[v]);
			u32 const r = find_root(uf, v);
			// May coincide with a root already pending; the set dedups it.
			pending.emplace(d[r], r);
		}
	}
}

}  // namespace

// Reads n, k, the n - 1 tree edges (1-based) and each vertex's group label
// in 1..k. Contracts every group into one component, then prints
// ceil(L / 2), where L is the number of degree-1 vertices in the contracted
// tree.
int main() {
	std::ios_base::sync_with_stdio(false);
	std::cin.tie(nullptr);

	u32 n = 0;
	u32 k = 0;
	std::cin >> n >> k;
	if (n == 0) {  // avoids the u32 underflow of `n - 1` below
		std::cout << 0 << '\n';
		return 0;
	}

	std::vector<std::vector<u32>> g(n);
	for (u32 i = 0; i + 1 < n; ++i) {
		u32 a = 0;
		u32 b = 0;
		std::cin >> a >> b;
		--a, --b;
		g[a].push_back(b);
		g[b].push_back(a);
	}

	std::vector<std::vector<u32>> states(k);
	for (u32 i = 0; i < n; ++i) {
		u32 s = 0;
		std::cin >> s;
		states[s - 1].push_back(i);
	}

	std::vector<u32> d(n);
	std::vector<u32> par(n);
	root_tree(g, par, d);

	std::vector<u32> uf(n);
	std::iota(uf.begin(), uf.end(), u32{0});
	contract_states(states, par, d, uf);

	// Each component root v != 0 is its component's shallowest vertex, so the
	// tree edge (v, par[v]) joins two different components: count it as one
	// edge of the contracted tree.
	std::vector<u32> cnt(n, 0);
	for (u32 v = 1; v < n; ++v) {
		if (find_root(uf, v) != v) continue;
		cnt[v] += 1;
		cnt[find_root(uf, par[v])] += 1;
	}

	auto const leaves = std::count(cnt.begin(), cnt.end(), u32{1});
	std::cout << (leaves + 1) / 2 << '\n';
}
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...