Submission #399801

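| # | Time | Username | Problem | Language | Result | Execution time | Memory |
|---|------|----------|---------|----------|--------|----------------|--------|
| 399801 | | azberjibiou | Mergers (JOI19_mergers) | C++17 | 100 / 100 | 1008 ms | 141484 KiB |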
// JOI 2019 Spring Camp "Mergers" (JOI19_mergers).
//
// Approach, as implemented below:
//   1. Root the tree at vertex 1 and build binary-lifting tables (par, dep).
//   2. For every state c, root[c] = LCA of all vertices whose state is c.
//   3. Using a union-find whose representative is always the topmost vertex
//      of its component, merge every vertex of state c upward, one
//      component at a time, until it reaches root[c]'s component.
//   4. Contract each component to a node. Count the contracted-tree nodes of
//      degree 1 (L) and print ceil(L / 2); this is 0 when everything
//      collapsed into a single component.
//
// Both the DFS and the union-find lookup are iterative: the tree can be a
// path of ~5e5 vertices, and recursion that deep overflows a default stack.
#include <iostream>
#include <utility>
#include <vector>

#pragma GCC optimize("O3")
#pragma GCC optimize("Ofast")
#pragma GCC optimize("unroll-loops")

using pii = std::pair<int, int>;

const int mxN = 502000;  // capacity for vertex ids 1..N and state ids 1..K
const int LOG = 20;      // binary-lifting levels; 2^19 exceeds any depth < mxN

int N, K;
std::vector<int> v[mxN];    // adjacency lists of the tree
std::vector<pii> E;         // every input edge (a, b)
std::vector<int> grp[mxN];  // grp[c] = vertices whose state is c
int par[mxN][LOG], dep[mxN];  // par[x][i] = 2^i-th ancestor of x (0 above the root)
int S[mxN];                 // S[x] = state of vertex x
int root[mxN];              // root[c] = LCA of all vertices in grp[c] (0 if none)
int uf[mxN];                // union-find parent; a representative satisfies uf[x] == x
int deg[mxN];               // degree of each representative in the contracted tree
int ans;                    // number of leaves (degree-1 nodes) in the contracted tree

// Fills par[*][0] and dep[*] for the subtree reached from `now`, where `pre`
// is the vertex not to walk back into. Iterative (explicit stack) so that
// very deep trees cannot overflow the call stack. The start vertex's own
// par/dep entries are left untouched, exactly as in a recursive DFS.
void dfs(int now, int pre) {
    std::vector<pii> stk;  // pending (vertex, vertex it was reached from)
    stk.push_back({now, pre});
    while (!stk.empty()) {
        auto [cur, from] = stk.back();
        stk.pop_back();
        for (int nxt : v[cur]) {
            if (nxt == from) continue;
            par[nxt][0] = cur;
            dep[nxt] = dep[cur] + 1;
            stk.push_back({nxt, cur});
        }
    }
}

// Returns the representative of a's union-find component and compresses
// the whole path onto it. Iterative, because the merge phase can build
// parent chains as long as the tree is deep.
int findpar(int a) {
    int rep = a;
    while (uf[rep] != rep) rep = uf[rep];
    while (uf[a] != rep) {
        int next = uf[a];
        uf[a] = rep;
        a = next;
    }
    return rep;
}

// Lowest common ancestor of a and b via the binary-lifting table `par`.
int lca(int a, int b) {
    if (dep[a] < dep[b]) std::swap(a, b);
    // Lift the deeper vertex to b's depth.
    for (int i = LOG - 1; i >= 0; i--) {
        if (dep[a] >= dep[b] + (1 << i)) a = par[a][i];
    }
    if (a == b) return a;
    // Lift both vertices to just below their LCA.
    for (int i = LOG - 1; i >= 0; i--) {
        if (par[a][i] != par[b][i]) {
            a = par[a][i];
            b = par[b][i];
        }
    }
    return par[a][0];
}

int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    std::cin >> N >> K;
    for (int i = 1; i <= N; i++) uf[i] = i;
    for (int i = 1; i < N; i++) {
        int a, b;
        std::cin >> a >> b;
        v[a].push_back(b);
        v[b].push_back(a);
        E.push_back({a, b});
    }

    dfs(1, -1);
    for (int i = 1; i < LOG; i++)
        for (int j = 1; j <= N; j++)
            par[j][i] = par[par[j][i - 1]][i - 1];

    for (int i = 1; i <= N; i++) std::cin >> S[i];
    for (int i = 1; i <= N; i++) grp[S[i]].push_back(i);

    // root[c] = LCA of every vertex of state c (vertex ids start at 1, so 0
    // means "not yet set").
    for (int c = 1; c <= K; c++) {
        for (int x : grp[c]) root[c] = root[c] ? lca(root[c], x) : x;
    }

    // Merge each vertex's component upward until it joins root[c]'s
    // component. Linking a component's top vertex to its tree parent keeps
    // every representative equal to the topmost vertex of its component.
    for (int c = 1; c <= K; c++) {
        for (int x : grp[c]) {
            int now = x;
            while (findpar(now) != findpar(root[c])) {
                now = findpar(now);
                uf[now] = par[now][0];
                now = par[now][0];
            }
        }
    }

    // Degrees of the contracted tree: only edges between different
    // components survive the contraction.
    for (const pii& e : E) {
        int a = findpar(e.first), b = findpar(e.second);
        if (a == b) continue;
        deg[a]++;
        deg[b]++;
    }
    for (int i = 1; i <= N; i++) {
        if (deg[i] == 1) ans++;
    }
    std::cout << (ans + 1) / 2 << '\n';
}
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...