답안 #138056

# 제출 시각 아이디 문제 언어 결과 실행 시간 메모리
138056 2019-07-29T07:39:27 Z anayk Mergers (JOI19_mergers) C++14
0 / 100
295 ms 35568 KB
#include <iostream>
#include <vector>
#include <queue>
#include <algorithm>

#define MAXN 500005
#define LOGM 20

int par[LOGM][MAXN];
int p[MAXN];
int lev[MAXN];
int mark[MAXN];
int type[MAXN];
int tar[MAXN];
std::vector<int> Adj[MAXN];
std::vector<int> coll[MAXN];
std::pair<int, int> ord[MAXN];
std::queue<int> q;
int deg[MAXN];

void dfs(int u, int p = 0)
{
	lev[u] = 0;
	if(p)
	{
		par[0][u] = p;
		lev[u] = lev[p]+1;
	}
	
	for(int i = 1; i < LOGM; i++)
	{
		if(par[i-1][u])
			par[i][u] = par[i-1][par[i-1][u]];
	}
	
	for(int v : Adj[u])
	{
		if(v == p)
			continue;
		
		dfs(v, u);
	}
}

int lca(int u, int v)
{
	if(lev[u] > lev[v])
		u ^= v ^= u ^= v;
	
	for(int i = LOGM-1; i >= 0; i--)
	{
		if(lev[v]-lev[u] >= (1 << i))
			v = par[i][v];
	}
	
	if(u == v)
		return u;
	
	for(int i = LOGM-1; i >= 0; i--)
	{
		if(par[i][u] != par[i][v])
		{
			u = par[i][u];
			v = par[i][v];
		}
	}
	
	return par[0][u];
}

int find(int u)
{
	if(p[u] < 0)
		return u;
	else
		return p[u] = find(p[u]);
}

int main()
{
	int n, k;
	std::cin >> n >> k;
	
	for(int i = 1; i < n; i++)
	{
		int a, b;
		std::cin >> a >> b;
		Adj[a].push_back(b);
		Adj[b].push_back(a);
	}
	
	dfs(1);
	
	for(int i = 1; i <= n; i++)
	{
		int t;
		std::cin >> t;
		coll[t].push_back(i);
		type[i] = t;
		
		if(tar[t] == 0)
			tar[t] = i;
		else
			tar[t] = lca(tar[t], i);
	}
	
	for(int i = 0; i < k; i++)
	{
		ord[i] = {lev[tar[i+1]], i+1};
	}
	
	std::sort(ord, ord+k);
	
	for(int i = 0; i < k; i++)
	{
		int t = ord[i].second;
		
		if(mark[t])
			continue;
		
		q.push(t);
		mark[t] = i+1;
		while(!q.empty())
		{
			int c = q.front(); q.pop();
			for(int j : coll[c])
			{
				if(j == tar[t])
					continue;
				
				int v = par[0][j];
				if(!mark[v])
				{
					mark[v] = i+1;
					q.push(v);
				}
			}
		}
	}
	
	int ans = 1;
	for(int i = 1; i <= n; i++)
	{
		for(int j : Adj[i])
			if(mark[type[i]] != mark[type[j]])
				deg[mark[type[i]]]++;
	}
	
	for(int i = 1; i <= k; i++)
		if(deg[i] == 1)
			ans++;
	
	std::cout << ans/2;
	
	return 0;
}
# 결과 실행 시간 메모리 Grader output
1 Correct 24 ms 23900 KB Output is correct
2 Incorrect 24 ms 23928 KB Output isn't correct
3 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Correct 24 ms 23900 KB Output is correct
2 Incorrect 24 ms 23928 KB Output isn't correct
3 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Correct 24 ms 23900 KB Output is correct
2 Incorrect 24 ms 23928 KB Output isn't correct
3 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Correct 202 ms 30880 KB Output is correct
2 Correct 295 ms 35568 KB Output is correct
3 Incorrect 29 ms 24312 KB Output isn't correct
4 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Correct 24 ms 23900 KB Output is correct
2 Incorrect 24 ms 23928 KB Output isn't correct
3 Halted 0 ms 0 KB -