# | Time | Username | Problem | Language | Result | Execution time | Memory |
---|---|---|---|---|---|---|---|
332586 | CaroLinda | Mergers (JOI19_mergers) | C++14 | 72 ms | 32996 KiB |
This submission was migrated from the previous version of oj.uz, which used a different grading machine. It may receive a different result if resubmitted.
#include <bits/stdc++.h>
// Shorthand macros. sz(x) yields a container's size as an int (used in dfs);
// all and ll are not used anywhere in this file.
#define all(x) x.begin(),x.end()
#define sz(x) (int)(x.size() )
#define ll long long
// Capacity of every per-vertex / per-state array. Vertices are indexed 1..n
// (see main), so n must stay below MAXN.
const int MAXN = 5e5+10 ;
using namespace std ;
// n: number of vertices; k: number of states (read in main, not used afterwards).
int n , k ;
// state[v]: state of vertex v.
// deg[r]: for a DSU root r, number of tree edges leaving r's component
//         (incremented in dfs, accumulated into the surviving root by join).
int state[MAXN] , deg[MAXN] ;
// dsu[v]: DSU parent of v (dsu[v] == v for a root).
// qtdState[s]: total number of vertices in the whole tree having state s.
int dsu[MAXN] , qtdState[MAXN] ;
// Tree adjacency lists (undirected; both directions are stored).
vector<int> adj[MAXN] ;
// ptr[v]: after dfs(v, ...), points at a map state -> count over v's subtree,
// holding only states not yet fully contained in it. The pointer is shared
// with v's heavy child (small-to-large merging).
map<int,int> *ptr[MAXN] ;
// Returns the representative (root) of x's DSU set.
// Every vertex on the walk from x to the root is re-pointed directly at the
// root (path compression), so later lookups along that path are O(1).
int find(int x)
{
    int root = x ;
    while( dsu[root] != root ) root = dsu[root] ;
    // Second pass: hang every vertex on the walked path directly off the root.
    while( x != root )
    {
        int next = dsu[x] ;
        dsu[x] = root ;
        x = next ;
    }
    return root ;
}
// Merges the DSU sets containing x and y (no-op if already in the same set).
// The surviving root is picked by a coin flip from rand(), and it absorbs the
// other root's count of component-leaving edges stored in deg[].
void join(int x, int y )
{
    int a = find(x) ;
    int b = find(y) ;
    if( a != b )
    {
        if( rand() % 2 ) swap(a, b) ;
        // a survives: attach b below it and accumulate b's edge count.
        dsu[b] = a ;
        deg[a] += deg[b] ;
    }
}
// Post-order DFS from x, whose parent is `father` (-1 for the root).
//
// After the call, *ptr[x] maps each state occurring in x's subtree to the
// number of subtree vertices with that state, keeping only states whose count
// is still below qtdState[state] (states fully contained in the subtree are
// erased). The map is inherited from the child with the largest subtree
// (small-to-large merging); the other children's maps are merged in and freed.
//
// For each child c: if c's map is empty, every state present in c's subtree is
// complete there, so the edge (x, c) separates two components and both DSU
// roots get deg++. Otherwise c is joined into x's DSU set.
//
// Returns the number of vertices in x's subtree; the caller relies on this to
// choose the heavy child.
// NOTE: recursive; the recursion depth equals the depth of the tree, so a
// path-shaped tree with n close to MAXN needs a correspondingly large stack.
int dfs(int x, int father)
{
    int bc = -1 ;      // heavy child (largest subtree), -1 if x is a leaf
    int subBc = 0 ;    // size of bc's subtree
    int sub = 1 ;      // size of x's subtree, x itself included
    for(int e : adj[x] )
    {
        if(e == father ) continue ;
        int subChild = dfs(e,x) ;
        sub += subChild ;
        if(subChild > subBc )
        {
            subBc = subChild ;
            bc = e ;
        }
    }

    // Reuse the heavy child's map instead of copying it; leaves start empty.
    if(bc == -1 ) ptr[x] = new map<int,int> ;
    else ptr[x] = ptr[bc] ;

    if(bc != -1 )
    {
        if( ptr[x]->empty() )
        {
            // bc's subtree is self-contained: (x, bc) is a component boundary.
            deg[ find(bc) ]++ ;
            deg[ find(x) ]++ ;
        }
        else join(x, bc) ;
    }

    for(int e : adj[x] )
    {
        if(e == father || e == bc ) continue ;
        if( ptr[e]->empty() )
        {
            // e's subtree is self-contained: (x, e) is a component boundary.
            deg[ find(x) ]++ ;
            deg[ find(e) ]++ ;
            // e's map is never shared with x (e != bc), so it can be freed.
            delete ptr[e] ;
            ptr[e] = nullptr ;
            continue ;
        }
        join(x, e) ;
        for(const auto &p : *ptr[e] )
        {
            int &cnt = (*ptr[x])[p.first] ;
            cnt += p.second ;
            // Every vertex of this state is now inside x's subtree.
            if( cnt == qtdState[p.first] ) ptr[x]->erase(p.first) ;
        }
        // Fully merged into x's map and never read again: release it.
        delete ptr[e] ;
        ptr[e] = nullptr ;
    }

    int &own = (*ptr[x])[ state[x] ] ;
    own++ ;
    if( own == qtdState[ state[x] ] ) ptr[x]->erase( state[x] ) ;

    // Previously missing: falling off the end of a non-void function is UB,
    // and the caller uses this value to select the heavy child.
    return sub ;
}
int main()
{
scanf("%d %d", &n, &k ) ;
for(int i = 0 , u , v ; i < n-1 ; i++ )
{
scanf("%d %d", &u, &v ) ;
adj[u].push_back(v) ;
adj[v].push_back(u) ;
}
for(int i = 1 ; i <= n ; i++ )
{
scanf("%d", &state[i] ) ;
dsu[i] = i ;
qtdState[ state[i] ]++ ;
}
dfs(1,-1) ;
int qtd = 0 ;
for(int i = 1 ; i <= n ; i++ )
if( find(i) == i && deg[i] == 1 ) qtd++ ;
printf("%d\n", (qtd+1)/2 ) ;
}
Compilation message (stderr)
# | Verdict | Execution time | Memory | Grader output |
---|---|---|---|---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|---|---|---|---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|---|---|---|---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|---|---|---|---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|---|---|---|---|
Fetching results... |