// This submission was migrated from a previous version of oj.uz, which used a different machine for grading; it may produce a different result if resubmitted.
#include<bits/stdc++.h>
using namespace std ;
using ll = long long ;
// mt19937 rng(chrono::high_resolution_clock::now().time_since_epoch().count());

// Capacity of every per-vertex / per-value array. Vertices are indexed 1..n
// and a[i] is used as an index into col[], so this assumes n < MAXN and
// every a[i] < MAXN (TODO confirm against the problem limits).
constexpr int MAXN = 200007 ;

int n , k ;                   // sizes read by input(); k is not used further below
int a[ MAXN ] ;               // a[i]: value ("colour") of vertex i
vector < int > v[ MAXN ] ;    // adjacency lists built from the n - 1 input edges
vector < int > col[ MAXN ] ;  // col[c]: every vertex i with a[i] == c
int tot ;                     // number of vertices visited by the last dfs()
int used[ MAXN ] ;            // 1 once the vertex has been processed as a centroid
int subtree[ MAXN ] ;         // subtree size in the rooting of the last dfs()
int prv[ MAXN ] ;             // parent in the rooting of the last dfs() (0 at its root)
int mrk[ MAXN ] ;             // `ori` argument of the last dfs() that reached the vertex
bool added[ MAXN ] ;          // colour already gathered for the current centroid
vector < int > clist ;        // colours gathered for the current centroid
int ans = MAXN ;              // smallest clist.size() - 1 over successful centroids
// Traverses the not-yet-used vertices reachable from `vertex` without stepping
// back into `lst` (pass -1 when `vertex` is the root). For every visited vertex
// it sets mrk[] = ori, counts it in tot, records its parent in prv[] (the root's
// prv[] is left to the caller) and computes its subtree[] size.
// An explicit stack is used instead of recursion: recursion nests as deep as
// the tree height (up to n on a path), which can overflow the call stack.
void dfs ( int vertex , int lst , int ori ) {
    vector < pair < int , int > > pending ;  // (vertex, parent) pairs still to visit
    vector < int > order ;                   // visiting order: parents precede children
    pending.push_back ( { vertex , lst } ) ;
    while ( !pending.empty ( ) ) {
        const int cur = pending.back ( ).first ;
        const int par = pending.back ( ).second ;
        pending.pop_back ( ) ;
        mrk[ cur ] = ori ;
        ++ tot ;
        subtree[ cur ] = 1 ;
        order.push_back ( cur ) ;
        for ( int nxt : v[ cur ] ) {
            if ( nxt == par || used[ nxt ] == 1 ) { continue ; }
            prv[ nxt ] = cur ;
            pending.push_back ( { nxt , cur } ) ;
        }
    }
    // Every descendant of a vertex appears after it in `order`, so a backwards
    // sweep finalises each subtree size before adding it to the parent.
    // order[0] is the root and has no parent to add into.
    for ( size_t i = order.size ( ) ; i -- > 1 ; ) {
        subtree[ prv[ order[ i ] ] ] += subtree[ order[ i ] ] ;
    }
}
// Starting at `vertex` (the root used by the preceding dfs(); `lst` is the
// neighbour to ignore, -1 at the root), repeatedly step into the first
// not-yet-used child whose subtree holds more than half of the tot vertices.
// The first vertex without such a child is returned as the centroid.
int get_centroid ( int vertex , int lst ) {
    int cur = vertex ;
    int from = lst ;
    for ( ; ; ) {
        int heavy = -1 ;  // vertex ids start at 1, so -1 means "none found"
        for ( int nb : v[ cur ] ) {
            if ( nb != from && used[ nb ] != 1 && 2 * subtree[ nb ] > tot ) {
                heavy = nb ;
                break ;
            }
        }
        if ( heavy == -1 ) {
            return cur ;
        }
        from = cur ;
        cur = heavy ;
    }
}
// Work queue for the colour-closure BFS in decompose(); emptied before each use.
queue < int > q ;

// One centroid-decomposition step on the component of not-yet-used vertices
// that contains `vertex`:
//   1. find the component's centroid c and re-root the component at it, so
//      prv[] points towards c and mrk[] == c on every vertex of the component;
//   2. starting from c's colour, keep adding the colour of the parent of every
//      vertex whose colour was already added; if an added colour has a vertex
//      with mrk[] != c (outside the component), c yields no candidate;
//   3. otherwise the candidate (colours added - 1) updates ans;
//   4. mark c as used and recurse into each remaining neighbouring component.
void decompose ( int vertex ) {
    tot = 0 , prv[ vertex ] = 0 ;
    dfs ( vertex , -1 , vertex ) ;
    vertex = get_centroid ( vertex , -1 ) ;
    // Re-root at the centroid: refreshes prv[], subtree[] and mrk[].
    tot = 0 , prv[ vertex ] = 0 ;
    dfs ( vertex , -1 , vertex ) ;

    // An earlier call may have aborted its BFS with vertices still queued;
    // those stale entries must not leak into this centroid's closure.
    while ( !q.empty ( ) ) { q.pop ( ) ; }

    bool ok = true ;
    // Records colour c as gathered and enqueues its vertices; clears `ok` as
    // soon as one of them lies outside the current component.
    auto absorb = [ & ] ( int c ) {
        added[ c ] = true ;
        clist.push_back ( c ) ;
        for ( int u : col[ c ] ) {
            if ( mrk[ u ] != vertex ) {
                ok = false ;
                return ;
            }
            q.push ( u ) ;
        }
    } ;

    absorb ( a[ vertex ] ) ;
    while ( ok && !q.empty ( ) ) {
        const int u = q.front ( ) ;
        q.pop ( ) ;
        const int parent = prv[ u ] ;  // 0 for the centroid itself
        if ( parent >= 1 && !added[ a[ parent ] ] ) {
            absorb ( a[ parent ] ) ;
        }
    }
    if ( ok ) {
        ans = min ( ans , static_cast < int > ( clist.size ( ) ) - 1 ) ;
    }

    // Reset the per-centroid bookkeeping before recursing.
    for ( int c : clist ) {
        added[ c ] = false ;
    }
    clist.clear ( ) ;

    used[ vertex ] = 1 ;
    for ( int nb : v[ vertex ] ) {
        if ( used[ nb ] == 0 ) {
            decompose ( nb ) ;
        }
    }
}
// Reads n and k, then n - 1 undirected edges into the adjacency lists v[],
// then the value a[i] of each vertex 1..n, indexing vertices by value in col[].
void input ( ) {
    cin >> n >> k ;
    for ( int e = 0 ; e + 1 < n ; ++ e ) {
        int from , to ;
        cin >> from >> to ;
        v[ from ].push_back ( to ) ;
        v[ to ].push_back ( from ) ;
    }
    for ( int node = 1 ; node <= n ; ++ node ) {
        int & value = a[ node ] ;
        cin >> value ;
        col[ value ].push_back ( node ) ;
    }
}
// Decomposes the component containing vertex 1 (the whole tree for a
// connected input) and prints the best answer found.
void solve ( ) {
    const int root = 1 ;
    decompose ( root ) ;
    cout << ans << '\n' ;
}
// Entry point: unsynchronised, untied stream I/O, then a single test case
// (uncomment the read of t to process several).
int main ( ) {
    ios_base :: sync_with_stdio ( false ) ;
    cin.tie ( nullptr ) ;
    int t = 1 ;
    // cin >> t ;
    for ( ; t > 0 ; -- t ) {
        input ( ) ;
        solve ( ) ;
    }
    return 0 ;
}
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |