제출 #236812

# | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리
236812 | — | MvC | Mergers (JOI19_mergers) | C++11 | 100 / 100 | 1103 ms | 127680 KiB
#pragma GCC optimize("O3")
#pragma GCC optimize("unroll-loops")
#include <algorithm>
#include <iostream>
#include <utility>
#include <vector>

// Solution to "Mergers" (JOI19_mergers).
//
// Input : n k, then n-1 tree edges "x y", then one colour per vertex 1..n.
// Output: (L + 1) / 2, where L is the number of degree-1 groups in the tree
//         obtained by contracting every edge that lies between a vertex and
//         the lowest common ancestor of its colour class.
//
// All tree traversals are iterative: the recursion depth of a recursive DFS
// equals the height of the tree, which can reach n on a path-shaped input.

namespace {

/// Disjoint-set forest with union by size and path compression.
class DisjointSets {
public:
    /// Creates `count` singleton sets labelled 0 .. count-1.
    explicit DisjointSets(int count) : parent_(count), size_(count, 1) {
        for (int v = 0; v < count; ++v) parent_[v] = v;
    }

    /// Returns the representative of the set containing `v`.
    int find(int v) {
        int root = v;
        while (parent_[root] != root) root = parent_[root];
        // Second pass: point every vertex on the walked path at the root.
        while (parent_[v] != root) {
            const int next = parent_[v];
            parent_[v] = root;
            v = next;
        }
        return root;
    }

    /// Merges the sets containing `a` and `b`; the larger set keeps its root.
    void unite(int a, int b) {
        a = find(a);
        b = find(b);
        if (a == b) return;
        if (size_[a] < size_[b]) std::swap(a, b);
        parent_[b] = a;
        size_[a] += size_[b];
    }

private:
    std::vector<int> parent_;
    std::vector<int> size_;
};

/// Rooted view of the tree: entry/exit times, depths, preorder and the
/// binary-lifting table used for lowest-common-ancestor queries.
struct RootedTree {
    int levels = 1;                    ///< number of rows in `up`
    std::vector<std::vector<int>> up;  ///< up[j][v] = 2^j-th ancestor of v (the root is its own ancestor)
    std::vector<int> tin;              ///< DFS entry time
    std::vector<int> tout;             ///< DFS exit time
    std::vector<int> depth;            ///< depth, with depth[root] == 1
    std::vector<int> preorder;         ///< vertices in the order the DFS entered them

    /// True when `a` is an ancestor of `b` (every vertex is its own ancestor).
    bool is_ancestor(int a, int b) const {
        return tin[a] <= tin[b] && tout[a] >= tout[b];
    }

    /// Lowest common ancestor of `a` and `b`.
    int lca(int a, int b) const {
        if (is_ancestor(a, b)) return a;
        if (is_ancestor(b, a)) return b;
        // Lift `a` as high as possible while it is still NOT an ancestor of `b`;
        // its parent is then the answer.
        for (int j = levels - 1; j >= 0; --j) {
            const int candidate = up[j][a];
            if (!is_ancestor(candidate, b)) a = candidate;
        }
        return up[0][a];
    }
};

/// Roots the tree given by `adjacency` (vertices 1 .. adjacency.size()-1) at
/// `root` using an explicit stack, then fills the binary-lifting table.
RootedTree build_rooted_tree(const std::vector<std::vector<int>>& adjacency, int root) {
    const int n = static_cast<int>(adjacency.size()) - 1;

    RootedTree tree;
    // Smallest row count with 2^levels >= n, so the jumps cover any depth gap.
    while ((1 << tree.levels) < n) ++tree.levels;
    tree.up.assign(tree.levels, std::vector<int>(n + 1, 0));
    tree.tin.assign(n + 1, 0);
    tree.tout.assign(n + 1, 0);
    tree.depth.assign(n + 1, 0);
    tree.preorder.reserve(n);

    int clock = 0;
    // Each frame is (vertex, index of the next neighbour still to be tried).
    std::vector<std::pair<int, std::size_t>> stack;

    tree.up[0][root] = root;
    tree.depth[root] = 1;
    tree.tin[root] = ++clock;
    tree.preorder.push_back(root);
    stack.emplace_back(root, 0);

    while (!stack.empty()) {
        const int v = stack.back().first;
        const std::size_t next = stack.back().second;
        if (next == adjacency[v].size()) {
            // Every neighbour handled: leave the vertex.
            tree.tout[v] = ++clock;
            stack.pop_back();
            continue;
        }
        ++stack.back().second;
        const int to = adjacency[v][next];
        if (to == tree.up[0][v]) continue;  // do not walk back to the parent

        tree.up[0][to] = v;
        tree.depth[to] = tree.depth[v] + 1;
        tree.tin[to] = ++clock;
        tree.preorder.push_back(to);
        stack.emplace_back(to, 0);
    }

    // Row j only needs the completed row j-1, so vertex order is irrelevant.
    for (int j = 1; j < tree.levels; ++j) {
        for (int v = 1; v <= n; ++v) {
            tree.up[j][v] = tree.up[j - 1][tree.up[j - 1][v]];
        }
    }
    return tree;
}

/// Reads one instance from `in` and writes the answer, followed by a newline,
/// to `out`. Edge endpoints are assumed to lie in [1, n] and to form a tree;
/// they are not validated.
void solve(std::istream& in, std::ostream& out) {
    int n = 0;
    int k = 0;
    if (!(in >> n >> k) || n < 1) {
        // No vertices: nothing can be separated, so no merge is needed.
        out << 0 << '\n';
        return;
    }

    std::vector<std::vector<int>> adjacency(n + 1);
    for (int e = 1; e < n; ++e) {
        int u = 0;
        int v = 0;
        in >> u >> v;
        adjacency[u].push_back(v);
        adjacency[v].push_back(u);
    }

    const RootedTree tree = build_rooted_tree(adjacency, 1);

    // members[c] lists the vertices of colour c. Colours outside [1, k] are
    // never visited by the loop below, so they are simply not stored.
    std::vector<std::vector<int>> members(k > 0 ? k + 1 : 1);
    for (int v = 1; v <= n; ++v) {
        int colour = 0;
        in >> colour;
        if (colour >= 1 && colour <= k) members[colour].push_back(v);
    }

    // climb[v] = number of consecutive edges, starting at v and going towards
    // the root, that must be contracted.
    std::vector<int> climb(n + 1, 0);
    for (int colour = 1; colour <= k; ++colour) {
        const std::vector<int>& group = members[colour];
        if (group.empty()) continue;

        int top = group[0];
        for (std::size_t idx = 1; idx < group.size(); ++idx) {
            top = tree.lca(top, group[idx]);
        }
        // Every vertex of the colour must stay connected to the colour's LCA.
        for (const int v : group) {
            climb[v] = tree.depth[v] - tree.depth[top];
        }
    }

    // Reverse preorder visits every vertex after all of its descendants, so
    // climb[v] is final when it is pushed up to the parent. Index 0 is the
    // root, which has no parent edge.
    DisjointSets groups(n + 1);
    const std::vector<int>& order = tree.preorder;
    for (std::size_t idx = order.size(); idx-- > 1;) {
        const int v = order[idx];
        const int parent = tree.up[0][v];
        climb[parent] = std::max(climb[parent], climb[v] - 1);
        if (climb[v] != 0) groups.unite(parent, v);
    }

    // Every edge that was not contracted joins two different groups.
    std::vector<int> degree(n + 1, 0);
    for (int v = 2; v <= n; ++v) {
        if (climb[v] != 0) continue;
        ++degree[groups.find(v)];
        ++degree[groups.find(tree.up[0][v])];
    }

    const int leaves = static_cast<int>(std::count(degree.begin() + 1, degree.end(), 1));
    out << (leaves + 1) / 2 << '\n';
}

}  // namespace

int main() {
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);
    solve(std::cin, std::cout);
    return 0;
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...