이 제출은 이전 버전의 oj.uz에서 채점하였습니다. 현재는 제출 당시와는 다른 서버에서 채점을 하기 때문에, 다시 제출하면 결과가 달라질 수도 있습니다.
#include<bits/stdc++.h>
using namespace std ;
typedef long long ll ;
// mt19937 rng(chrono::high_resolution_clock::now().time_since_epoch().count());
// Upper bound (exclusive) on vertex ids and colour ids; arrays are indexed 1..n / 1..k.
constexpr int MAXN = 500007 ;
// Number of binary-lifting levels; get_lca can lift by at most 2^LOG - 1,
// so every depth must stay below 2^LOG.
constexpr int LOG = 20 ;
int n , k ;                          // vertex count, colour count
std::vector < int > v[ MAXN ] ;      // adjacency lists of the tree
int a[ MAXN ] ;                      // a[i] = colour of vertex i
int lvl[ MAXN ] ;                    // depth of each vertex (start vertex of init has the value it already held, 0 for a fresh run)
int LCA[ MAXN ][ LOG ] ;             // LCA[x][i] = 2^i-th ancestor of x (0 once past the root)
int st[ MAXN ] , tp ;                // st[x] = preorder number of x, tp = last number issued
std::vector < int > col[ MAXN ] ;    // col[c] = vertices having colour c
void init ( int vertex , int prv ) {
st[ vertex ] = ++ tp ;
for ( int i = 1 ; i < LOG ; ++ i ) {
LCA[ vertex ][ i ] = LCA[ LCA[ vertex ][ i - 1 ] ][ i - 1 ] ;
}
for ( auto x : v[ vertex ] ) {
if ( x == prv ) { continue ; }
lvl[ x ] = lvl[ vertex ] + 1 ;
LCA[ x ][ 0 ] = vertex ;
init ( x , vertex ) ;
}
}
// ori[c]: lowest common ancestor of all vertices of colour c (computed in solve()).
int ori[ MAXN ] = { } ;
int get_lca ( int x , int y ) {
for ( int i = LOG - 1 ; i >= 0 ; -- i ) {
if ( lvl[ x ] - ( 1 << i ) >= lvl[ y ] ) {
x = LCA[ x ][ i ] ;
}
if ( lvl[ y ] - ( 1 << i ) >= lvl[ x ] ) {
y = LCA[ y ][ i ] ;
}
}
for ( int i = LOG - 1 ; i >= 0 ; -- i ) {
if ( LCA[ x ][ i ] != LCA[ y ][ i ] ) {
x = LCA[ x ][ i ] ;
y = LCA[ y ][ i ] ;
}
}
if ( x != y ) { x = LCA[ x ][ 0 ] ; }
return x ;
}
// dp[u] (filled by dfs): sum over children c of max(bad[c], dp[c]), i.e. the number of
// leaves of the tree obtained by contracting all non-cuttable edges that lie strictly below u.
int dp[ MAXN ] = { } ;
// bad[u] = 1 (set by dfs) when no colour class has vertices on both sides of the edge
// between u and its parent; it is always set for the root, which has no such edge.
int bad[ MAXN ] = { } ;
// Leaf counter of the contracted tree, accumulated by dfs() and solve().
int ans { 0 } ;
int dfs ( int vertex , int prv ) {
int ret = lvl[ ori[ a[ vertex ] ] ] ;
for ( auto x : v[ vertex ] ) {
if ( x == prv ) { continue ; }
int aux = dfs ( x , vertex ) ;
ret = min ( ret , aux ) ;
dp[ vertex ] += max ( bad[ x ] , dp[ x ] ) ;
}
if ( ret >= lvl[ vertex ] ) {
bad[ vertex ] = 1 ;
if ( dp[ vertex ] == 0 && vertex != 1 ) {
++ ans ;
}
}
return ret ;
}
// Reads from stdin: "n k", then n-1 edges "x y", then the colours a[1..n].
// Builds the adjacency lists v[] and the per-colour vertex buckets col[].
void input ( ) {
    cin >> n >> k ;
    for ( int edges_left = n - 1 ; edges_left > 0 ; -- edges_left ) {
        int from , to ;
        cin >> from >> to ;
        v[ from ].push_back ( to ) ;
        v[ to ].push_back ( from ) ;
    }
    for ( int vert = 1 ; vert <= n ; ++ vert ) {
        int & colour = a[ vert ] ;
        cin >> colour ;
        col[ colour ].push_back ( vert ) ;
    }
}
void solve ( ) {
init ( 1 , -1 ) ;
auto cmp = [ & ] ( int x , int y ) {
return ( st[ x ] < st[ y ] ) ;
};
for ( int i = 1 ; i <= k ; ++ i ) {
sort ( col[ i ].begin ( ) , col[ i ].end ( ) , cmp ) ;
ori[ i ] = get_lca ( col[ i ].front ( ) , col[ i ].back ( ) ) ;
}
dfs ( 1 , -1 ) ;
for ( int i = 2 ; i <= n ; ++ i ) {
if ( dp[ i ] == dp[ 1 ] ) {
ans += bad[ i ] ;
}
}
cout << ( ans + 1 ) / 2 << "\n" ;
}
// Entry point: configures fast iostreams and processes a single test case.
int main ( ) {
    // Unsynchronised, untied iostreams: input size is linear in n.
    ios_base :: sync_with_stdio ( false ) ;
    cin.tie ( nullptr ) ;
    int t = 1 ;
    // cin >> t ;   // multi-test reading disabled: exactly one case is processed
    while ( t -- ) {
        input ( ) ;
        solve ( ) ;
    }
    return 0 ;
}
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |