Submission #1249161

# | Time | Username | Problem | Language | Result | Execution time | Memory
1249161 | arashmemar | Capital City (JOI20_capital_city) | C++20
100 / 100
906 ms | 42816 KiB
#include <bits/stdc++.h>
using namespace std;

// Array bound for vertex and color ids (1-based, with slack).
// solve() also returns this value as its "infeasible" sentinel answer.
const int maxn = 200100;

// Input tree and coloring.
vector <int> ne[maxn];    // ne[v]: neighbours of vertex v
vector <int> poec[maxn];  // poec[col]: every vertex whose color is col
int c[maxn];              // c[v]: color of vertex v

// Per-vertex traversal state.
bool mark[maxn];          // set while v is on the current DFS path, and kept set for chosen centroids
int p[maxn];              // parent of v in the tree rooted at the current centroid (0 at the centroid)
int vu[maxn];             // subtree size of v from the most recent find() pass

// Per-color state.
bool s[maxn];             // scratch "already seen" flag
bool block[maxn];         // color may not be absorbed inside the current component
int nos[maxn];            // number of branches of the current centroid that contain the color

vector <int> nb;          // colors collected by check() that are about to be blocked

// Centroid search over the component reachable from v through unmarked
// vertices, with v treated as the root.
//
// Side effects, for every vertex w visited:
//   * clears nos[c[w]] and s[c[w]] (per-color scratch used by the caller),
//   * sets p[child] = w for each tree edge descended from w,
//   * stores the size of the subtree rooted at w in vu[w].
// mark[] is set on the way down and cleared on the way back, so it is
// unchanged on return.
//
// Returns a vertex w whose child subtrees each have at most k / 2 vertices
// and whose own subtree has at least (k + 1) / 2 vertices (i.e. a centroid
// when k is the component size), or -1 if none is found.
int find(int v, int k)
{
	const int col = c[v];
	nos[col] = 0;
	s[col] = 0;

	mark[v] = 1;
	vu[v] = 1;

	int best = -1;
	bool balanced = true;
	for (const int u : ne[v])
	{
		if (mark[u])
			continue;

		p[u] = v;
		const int found = find(u, k);
		if (found > best)
			best = found;

		vu[v] += vu[u];
		if (vu[u] > k / 2)
			balanced = false;
	}

	if (balanced && vu[v] >= (k + 1) / 2)
		best = v;

	mark[v] = 0;
	return best;
}

// Walks the component reachable from v through unmarked vertices and writes
// op into s[col] for every color col it meets.
// With op == true, nos[col] is incremented only when s[col] was still false,
// so one op == true pass over a branch (starting from cleared s[]) counts each
// color at most once; a following op == false pass clears s[] again.
// mark[] is set on the way down and cleared on the way back.
void dfs(int v, bool op)
{
	const int col = c[v];
	if (op && !s[col])
		++nos[col];
	s[col] = op;

	mark[v] = 1;
	for (const int u : ne[v])
	{
		if (mark[u] == 0)
			dfs(u, op);
	}
	mark[v] = 0;
}

// Walks the component reachable from v through unmarked vertices.
// For every visited vertex whose color has nos[] > 1 and is not blocked yet,
// it appends that color to nb. The append happens once per such vertex, so nb
// may contain duplicates.
// On return mark[v] is 0, even if it was set on entry.
void check(int v)
{
	const int col = c[v];
	const bool needs_block = nos[col] > 1 && !block[col];
	if (needs_block)
		nb.push_back(col);

	mark[v] = 1;
	for (const int u : ne[v])
	{
		if (!mark[u])
			check(u);
	}
	mark[v] = 0;
}

// One centroid-decomposition step for the component that contains v.
//
// v  : any vertex of the current component (vertices reachable from v
//      without crossing a marked vertex).
// sz : number of vertices in that component.
//
// Returns the smallest count found, over this component and all of its
// sub-components, of extra colors a candidate color is forced to absorb.
// Returns a value >= maxn when no candidate here is feasible. A candidate is
// infeasible when it must absorb a color whose block[] flag is set.
int solve(int v, int sz)
{
	if (sz == 1)
	{
		// Single vertex: its color absorbs nothing, unless it is blocked.
		if (block[c[v]])
		{
			return maxn;
		}
		return 0;
	}

	// Find a centroid, then re-root the component at it so that p[] and vu[]
	// describe the tree hanging from the centroid (p[centroid] = 0).
	v = find(v, sz);
	p[v] = 0;
	find(v, 0);

	// BFS over the colors the centroid's color is forced to absorb. Each
	// queued vertex walks up p[] towards the centroid, absorbing every vertex
	// on the way. The walk stops at an already-absorbed vertex or after the
	// centroid (whose parent is 0). Each newly met color costs one merge and
	// queues all of its vertices.
	queue <int> q;
	vector <int> visited;  // vertices whose mark[] and s[c[.]] must be undone
	s[c[v]] = 1;
	int ans = 0;
	if (block[c[v]])
	{
		ans = maxn;
	}
	else
	{
		for (auto o : poec[c[v]])
		{
			q.push(o);
		}
	}
	while (q.size())
	{
		int x = q.front();
		while (x and mark[x] == 0)
		{
			mark[x] = 1;
			visited.push_back(x);
			if (!s[c[x]])
			{
				if (block[c[x]])
				{
					// Absorbing this color would reach outside the component.
					ans = maxn;
				}
				else
				{
					ans++;
					s[c[x]] = 1;
					for (auto o : poec[c[x]])
					{
						q.push(o);
					}
				}
			}
			x = p[x];
		}
		q.pop();
	}

	// Undo the BFS marks so the component looks untouched again.
	for (int o : visited)
	{
		mark[o] = 0;
		s[c[o]] = 0;
	}

	// Remove the centroid permanently. For every remaining branch, count in
	// nos[] how many branches each color appears in. The centroid's own color
	// is pre-set to 2 so that it is always collected.
	mark[v] = 1;

	nos[c[v]] = 2;

	for (auto u : ne[v])
	{
		if (mark[u])
		{
			continue;
		}
		dfs(u, 1);
		dfs(u, 0);
	}

	// Colors present in more than one branch (or equal to the centroid's
	// color) cannot be satisfied inside a single sub-component.
	check(v);
	mark[v] = 1;  // check() clears mark[v] on exit; restore it

	vector <int> tmp;
	tmp.swap(nb);  // take ownership of the collected colors and leave nb empty

	for (auto o : tmp)
	{
		block[o] = 1;
	}

	for (auto u : ne[v])
	{
		if (mark[u])
		{
			continue;
		}
		ans = min(ans, solve(u, vu[u]));
	}

	// Lift only the blocks added at this level.
	for (auto o : tmp)
	{
		block[o] = 0;
	}

	return ans;
}

// Reads the tree (n, k, then n - 1 edges, then the color of each of the
// n vertices). Prints the answer from solve() over the whole tree.
int main()
{
	// About 3n integers are read: decouple iostreams from C stdio.
	ios::sync_with_stdio(false);
	cin.tie(nullptr);

	int n = 0, k = 0;
	// k is consumed from the input but not otherwise used here.
	if (!(cin >> n >> k) || n <= 0)
	{
		return 0;
	}
	for (int i = 1; i < n; i++)
	{
		int v, u;
		cin >> v >> u;
		ne[v].push_back(u);
		ne[u].push_back(v);
	}

	for (int i = 1; i <= n; i++)
	{
		cin >> c[i];
		poec[c[i]].push_back(i);
	}
	cout << solve(1, n) << '\n';
	return 0;
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...