Submission #236814

# | Time | Username | Problem | Language | Result | Execution time | Memory
236814 | Knps4422 | Mergers (JOI19_mergers) | C++14
100 / 100
1701 ms | 141864 KiB
//#pragma optimization_level 3
//#pragma GCC optimize("Ofast,no-stack-protector,unroll-loops,fast-math,O3")
//#pragma GCC target("sse,sse2,sse3,ssse3,sse4,popcnt,abm,mmx,avx,tune=native")
#include<bits/stdc++.h>
/*
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
#include <ext/pb_ds/detail/standard_policies.hpp>
using namespace __gnu_pbds;
typedef tree<int,null_type,less<int>,rb_tree_tag,tree_order_statistics_node_update>ordset;
*/
// Short-hand macros used throughout the solution.
#define fr first
#define sc second
#define vec vector
#define pb push_back
#define pii pair<int, int>
// forn(x, y): loop x over 1..y inclusive.
// NOTE(review): y is not parenthesised, so the (int) cast binds only to its first operand.
#define forn(x,y) for(int x = 1 ; x <= (int)y ; ++x)
#define all(x) (x).begin(),(x).end()
// Stream-tuning shorthand: unties cin/cout and turns off C-stdio synchronisation.
#define fast cin.tie(0);cout.tie(0);cin.sync_with_stdio(0);cout.sync_with_stdio(0);
 
using namespace std;
 
// Template typedefs/constants; ll, uint, point, linf, mod and inf are not
// referenced by the functions in this listing.
typedef long long ll;
typedef unsigned int uint;
typedef complex<int> point;
const int nmax = 500005;
const ll linf = 1e18;
const ll mod = 998244353;
const int inf = INT_MAX;
 
// n: number of vertices, k: number of colours, tt: DFS timestamp counter,
// lvl[v]: depth of v (the root gets depth 1).
int n, k, tt, lvl[nmax];
// Adjacency lists of the tree, vertices 1..n.
vec < int > g[nmax];
// col[v]: colour of vertex v.
int col[nmax];
// Binary-lifting table: up[v][j] is the 2^j-th ancestor of v. Only j < 20 is
// filled; the root is its own ancestor at every level.
int up[nmax][21];
// Disjoint-set forest: indices 1..n are vertices, n + c represents colour c.
// Vertices are initially attached to their colour node, so set roots are
// always colour nodes and siz is only meaningful for those.
int parent[2*nmax];
int siz[2*nmax];
// targ[v]: smallest depth the set containing v has to reach; initialised to
// the depth of the LCA of v's colour class and propagated upward.
int targ[nmax];
// DFS entry/exit timestamps used for the ancestor test.
int tin[nmax],tout[nmax];
// cn[n + c]: vertices that have colour c.
vec < int > cn[2*nmax];
// deg[r]: number of tree edges leaving the DSU set rooted at r
// (its degree in the contracted tree).
int deg[2*nmax];
/*
 * DSU lookup: returns the representative of x's set and re-hangs every node
 * visited on the way directly under that representative (full path compression).
 */
int find(int x){
	int root = x;
	while(parent[root] != root) root = parent[root];
	// Second pass: point each node on the walked path straight at the root.
	while(x != root){
		const int nxt = parent[x];
		parent[x] = root;
		x = nxt;
	}
	return root;
}
void merge(int x , int y){
	x = find(x);
	y = find(y);
	if(x == y)return ;
	if(siz[x] < siz[y])swap(x,y);
	parent[y] = x;
	siz[x] += siz[y];
}
/*
 * Sort order for vertices: deeper vertices (larger lvl) first, and at equal
 * depth the smaller index first.
 */
bool cmp(int x, int y){
	// Lexicographic (lvl[y], x) < (lvl[x], y) means exactly:
	// lvl[x] > lvl[y], or equal depths and x < y.
	return std::tie(lvl[y], x) < std::tie(lvl[x], y);
}


void dfs(int x,int p)
{
	tin[x]=++tt;
    up[x][0]=p;
    lvl[x]=lvl[p]+1;
    for(int i=1;i<20;i++)up[x][i]=up[up[x][i-1]][i-1];
    for(int i=0;i<(int)g[x].size();i++)if(g[x][i]!=p)dfs(g[x][i],x);
	tout[x]=++tt;
}
// True iff x is an ancestor of y (every vertex is its own ancestor), i.e. y's
// [tin, tout] interval lies inside x's.
bool anc(int x , int y){
	if(tin[y] < tin[x]) return false;
	return tout[y] <= tout[x];
}

int lca(int x,int y)
{
    if(anc(x,y))return x;
    if(anc(y,x))return y;
    for(int i=19;i>=0;i--)if(!anc(up[x][i],y))x=up[x][i];
    return up[x][0];
}

/*
 * Handles one vertex; main() calls it deepest-first. When targ[nod] lies above
 * nod's own depth, the edge to the parent must be contracted. In that case nod
 * is merged into the parent's set and the smaller target depth is passed to the parent.
 */
void rec(int nod){
	if(targ[nod] >= lvl[nod]) return;
	const int par = up[nod][0];
	merge(nod, par);
	if(targ[nod] < targ[par]) targ[par] = targ[nod];
}
/*
 * Reads the tree and the vertex colours, then contracts into one DSU set every
 * group of vertices that must share a state. Finally it prints ceil(L / 2),
 * where L is the number of colour sets of degree 1 in the contracted tree.
 */
int main(){
	// Up to ~1.5 million integers are read; unsynchronised, untied streams
	// avoid the per-operation cost of C-stdio synchronisation.
	std::ios::sync_with_stdio(false);
	std::cin.tie(nullptr);

	cin >> n >> k;
	int a, b;
	forn(e,n-1){
		cin >> a >> b;
		g[a].pb(b);
		g[b].pb(a);
	}
	// DSU layout: vertices are 1..n and colour c is node n + c. Each vertex
	// starts attached to its colour node, so each colour is already one set.
	forn(i,n){
		cin >> col[i];
		cn[col[i]+n].pb(i);
		parent[i] = n + col[i];
	}
	forn(i,k){
		siz[n+i] = cn[n + i].size();
		parent[n+i] = n+i;
	}
	dfs(1,1);
	// Deepest vertices first (ties by index), so rec() sees children before parents.
	vec < int > vc;
	vc.reserve(n);
	forn(i,n)vc.pb(i);
	sort(all(vc),cmp);
	// targ[v] = depth of the LCA of all vertices that share v's colour.
	forn(i,k){
		// A colour without vertices has no LCA; reading cn[i+n][0] would be out of bounds.
		if(cn[i+n].empty())continue;
		int t = cn[i+n][0];
		for(int j : cn[i+n]){
			t = lca(t,j);
		}
		for(int j : cn[n+i]){
			targ[j] = lvl[t];
		}
	}
	for(int i : vc)rec(i);
	// Every tree edge whose endpoints lie in different sets adds 1 to both sets' degrees.
	forn(i,n){
		if(find(i) != find(up[i][0])){
			deg[find(i)]++;
			deg[find(up[i][0])]++;
		}
	}
	int lef = 0;
	forn(i,k){
		if(deg[i+n] == 1)lef++;
	}
	cout << (lef+1)/2 << '\n';
	return 0;
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...