Submission #1166245

| # | Time | Username | Problem | Language | Result | Execution time | Memory |
| 1166245 | — | MateiKing80 | Mergers (JOI19_mergers) | C++20 | 100 / 100 | 1381 ms | 158408 KiB |
#include <bits/stdc++.h>

using namespace std;

// Capacity for node-indexed arrays. Nodes are numbered 1..n; index 0 is used as
// the sentinel parent of the root (depth 0, every lift of 0 stays 0 because the
// globals are zero-initialised and the lifting loop in main starts at j = 1).
const int N = 500005;
// vec[v]  : adjacency list of the input tree.
// stat[s] : the cities whose state is s (filled while reading input in main).
vector<int> vec[N], stat[N];
// lift[p][v] : the 2^p-th ancestor of v (0 once past the root); used by lca().
// marte[v]   : difference counter for same-state paths. main adds +1 at both
//              endpoints of each path and -2 at their LCA; after dfs2 sums the
//              subtrees, marte[v] is the number of such paths crossing the edge
//              (v, parent(v)). A value of 0 means no state forces that edge.
// adanc[v]   : depth of v ("adancime"); the root has depth 1, sentinel 0 has 0.
// fin[v]     : representative (topmost node) of v's component once every edge
//              with marte == 0 is cut; filled by dfs3.
int lift[20][N], marte[N], adanc[N], fin[N];
// p1 / p2   : preorder / postorder timestamps written by dfs.
// pp1 / pp2 : running counters for those timestamps.
int p1[N], p2[N], pp1 = 0, pp2 = 0;
// newVec : adjacency of the compressed tree over component representatives
//          (built by dfs3); main only inspects its node degrees.
vector<int> newVec[N];

// Depth-first traversal of the tree from `nod`, whose parent is `papa`.
// For every node v that is reached it records:
//   p1[v]      - preorder timestamp (next value of pp1),
//   adanc[v]   - depth, i.e. 1 + depth of the parent,
//   lift[0][v] - the direct parent,
//   p2[v]      - postorder timestamp (next value of pp2).
// Neighbours are processed in vec[v] order, so the timestamps are exactly those of
// the plain recursive traversal. An explicit stack replaces recursion because a
// path-shaped tree with ~5e5 nodes would otherwise recurse that deep and can
// overflow the call stack.
void dfs(int nod, int papa)
{
	struct Frame
	{
		int node, parent;
		std::size_t next; // index of the next neighbour of `node` to examine
	};
	std::vector<Frame> stk;

	// Entering a node: stamp its preorder data, then push its frame.
	auto enter = [&](int v, int par)
	{
		p1[v] = ++pp1;
		adanc[v] = 1 + adanc[par];
		stk.push_back({v, par, 0});
	};

	enter(nod, papa);
	while (!stk.empty())
	{
		const int v = stk.back().node;
		const int parent = stk.back().parent;
		if (stk.back().next < vec[v].size())
		{
			// Advance the cursor before enter() may reallocate the stack.
			const int child = vec[v][stk.back().next++];
			if (child != parent)
				enter(child, v);
			continue;
		}
		// Every neighbour handled: stamp the postorder data and leave v.
		lift[0][v] = parent;
		p2[v] = ++pp2;
		stk.pop_back();
	}
}

// Contracts the tree rooted at `nod` (parent `papa`) into components. The edge
// (v, parent(v)) is treated as cut exactly when marte[v] == 0, i.e. no
// same-state path crosses it.
//   fin[v] - topmost node of v's component: v itself when its parent edge is
//            cut, otherwise fin[parent(v)].
// For every cut edge, the compressed-tree edge (v, fin[parent(v)]) is appended to
// newVec in both directions, after v's subtree is finished. This is the same
// order as the recursive formulation. An explicit stack avoids recursion depth
// proportional to the tree height (up to ~5e5).
void dfs3(int nod, int papa)
{
	struct Frame
	{
		int node, parent;
		std::size_t next; // index of the next neighbour of `node` to examine
	};
	std::vector<Frame> stk;

	// Entering a node fixes its component representative.
	auto enter = [&](int v, int par)
	{
		fin[v] = (marte[v] == 0) ? v : fin[par];
		stk.push_back({v, par, 0});
	};

	enter(nod, papa);
	while (!stk.empty())
	{
		const int v = stk.back().node;
		const int parent = stk.back().parent;
		if (stk.back().next < vec[v].size())
		{
			// Advance the cursor before enter() may reallocate the stack.
			const int child = vec[v][stk.back().next++];
			if (child != parent)
				enter(child, v);
			continue;
		}
		stk.pop_back();
		// v's subtree is done; if its parent edge is cut, link v to the parent's
		// representative in the compressed tree. The starting node has no such edge.
		if (!stk.empty() && marte[v] == 0)
		{
			const int rep = fin[stk.back().node];
			newVec[v].push_back(rep);
			newVec[rep].push_back(v);
		}
	}
}

// Lowest common ancestor of x and y, using the binary-lifting table lift[][] and
// the depths in adanc[]. A jump past the root lands on the sentinel node 0, whose
// depth is 0. Such a jump therefore never passes the depth test below, because
// every real node has depth >= 1.
int lca(int x, int y)
{
	int deep = x;
	int shallow = y;
	if (adanc[deep] < adanc[shallow])
		std::swap(deep, shallow);

	// Phase 1: raise the deeper node to the depth of the shallower one.
	for (int level = 19; level >= 0; --level)
	{
		const int up = lift[level][deep];
		if (adanc[up] >= adanc[shallow])
			deep = up;
	}
	if (deep == shallow)
		return deep;

	// Phase 2: raise both while their ancestors differ. They stop on the two
	// children of the LCA, so the answer is one step further up.
	for (int level = 19; level >= 0; --level)
	{
		if (lift[level][deep] != lift[level][shallow])
		{
			deep = lift[level][deep];
			shallow = lift[level][shallow];
		}
	}
	return lift[0][deep];
}

// Replaces every marte[v] in the tree rooted at `nod` (parent `papa`) by the sum of
// the original marte values over v's subtree. With the +1 / +1 / -2 path updates
// made in main, this turns the difference counters into per-edge coverage counts
// for the edge (v, parent(v)). The starting node does not push its total into
// `papa`. An explicit stack replaces recursion so that deep, path-like trees
// cannot overflow the call stack.
void dfs2(int nod, int papa)
{
	struct Frame
	{
		int node, parent;
		std::size_t next; // index of the next neighbour of `node` to examine
	};
	std::vector<Frame> stk;
	stk.push_back({nod, papa, 0});

	while (!stk.empty())
	{
		const int v = stk.back().node;
		const int parent = stk.back().parent;
		if (stk.back().next < vec[v].size())
		{
			// Advance the cursor before push_back may reallocate the stack.
			const int child = vec[v][stk.back().next++];
			if (child != parent)
				stk.push_back({child, v, 0});
			continue;
		}
		stk.pop_back();
		// v's subtree is complete: fold its total into the parent frame.
		if (!stk.empty())
			marte[stk.back().node] += marte[v];
	}
}

// Reads n cities, the n-1 roads of the tree, and the state of every city. It then:
//   1. marks, for every state, the minimal subtree spanning that state's cities;
//   2. contracts every covered edge (dfs2 + dfs3);
//   3. prints ceil(L / 2), where L is the number of leaves of the contracted tree.
// This is the answer for JOI 2019 "Mergers" (see the submission header).
int main()
{
	ios_base::sync_with_stdio(false);
	cin.tie(0);

	int n, k;
	cin >> n >> k;
	for (int i = 1; i < n; i ++)
	{
		int a, b;
		cin >> a >> b;
		vec[a].push_back(b);
		vec[b].push_back(a);
	}

	// Root the tree at 1 (0 is the sentinel parent) and build the lifting table.
	dfs(1, 0);
	for (int i = 1; i < 20; i ++)
		for (int j = 1; j <= n; j ++)
			lift[i][j] = lift[i - 1][lift[i - 1][j]];

	for (int i = 1; i <= n; i ++)
	{
		int s;
		cin >> s;
		stat[s].push_back(i);
	}

	// In a tree, the union of the paths from every city of a state to ONE fixed
	// city of that state is exactly the minimal subtree spanning the state.
	// Each path of that union must be marked, so one path per city suffices.
	// Every path j -> anchor is recorded as a difference update: +1 at both
	// endpoints and -2 at their LCA. dfs2 later turns these into per-edge counts.
	for (int i = 1; i <= k; i ++)
	{
		if (stat[i].empty())
			continue; // a state with no cities constrains no edge
		const int anchor = stat[i][0];
		for (int j : stat[i])
		{
			marte[j] ++;
			marte[anchor] ++;
			marte[lca(j, anchor)] -= 2;
		}
	}
	dfs2(1, 0);
	dfs3(1, 0);

	// After contraction, only edges covered by no state remain. The printed
	// answer is ceil(leaves / 2) of that compressed tree: each merge can join
	// two leaves.
	int nrFrunze = 0; // number of leaves ("frunze") in the compressed tree
	for (int i = 1; i <= n; i ++)
		if (newVec[i].size() == 1)
			nrFrunze ++;
	cout << (nrFrunze + 1) / 2 << '\n';
}
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...