Submission #901607

# | Time | Username | Problem | Language | Result | Execution time | Memory
901607 | Shayan86 | Mergers (JOI19_mergers) | C++17
100 / 100
831 ms | 202836 KiB
#include <algorithm>
#include <iostream>
#include <utility>
#include <vector>

// Submission for oj.uz JOI19_mergers ("Mergers").
//
// Input: n, k; then n-1 undirected edges of a tree on vertices 1..n; then the
// state d[i] (assumed to be in 1..k) of every vertex i.
//
// An edge is "covered" when some state has vertices on both of its sides.
// Contracting every covered edge yields a tree whose edges are exactly the
// uncovered edges. The program prints ceil(leaves / 2) for that contracted
// tree (0 when it collapses to a single vertex).
//
// All tree passes are iterative: the former recursive dfs/pre/ff could recurse
// n levels deep (n ~ 5e5 on a path) and overflow a default-sized stack.
// Arrays are sized from n and the lifting table holds int rather than
// long long.

namespace {

// Returns ceil(leaves / 2) of the tree obtained by contracting covered edges.
//   n      - number of vertices (1-based, 1..n); edges must form a tree.
//   k      - number of states; state[v] is assumed to lie in [1, k].
//   edges  - the n-1 tree edges.
//   state  - state[1..n]; state[0] is ignored.
int solveMergers(int n, int k,
                 const std::vector<std::pair<int, int>>& edges,
                 const std::vector<int>& state) {
    if (n <= 1) return 0;  // a single vertex has no edges to separate
    constexpr int kRoot = 1;

    std::vector<std::vector<int>> adj(n + 1);
    for (const auto& [a, b] : edges) {
        adj[a].push_back(b);
        adj[b].push_back(a);
    }

    // Iterative DFS preorder from the root. Each subtree occupies a contiguous
    // range of `tin`, and every parent precedes its children in `order`.
    std::vector<int> parent(n + 1, 0);
    std::vector<int> depth(n + 1, 0);
    std::vector<int> tin(n + 1, 0);
    std::vector<int> order;
    order.reserve(n);
    parent[kRoot] = kRoot;  // self-parent so binary lifting saturates at the root
    std::vector<int> stack{kRoot};
    while (!stack.empty()) {
        const int v = stack.back();
        stack.pop_back();
        tin[v] = static_cast<int>(order.size());
        order.push_back(v);
        for (int u : adj[v]) {
            if (u == parent[v]) continue;
            parent[u] = v;
            depth[u] = depth[v] + 1;
            stack.push_back(u);
        }
    }

    // Binary lifting: up[j][v] is the 2^j-th ancestor of v (clamped at the root).
    int log = 1;
    while ((1 << log) <= n) ++log;  // depth <= n-1 < 2^log
    std::vector<std::vector<int>> up(log, parent);
    for (int j = 1; j < log; ++j)
        for (int v = 1; v <= n; ++v) up[j][v] = up[j - 1][up[j - 1][v]];

    auto lca = [&](int a, int b) {
        if (depth[a] < depth[b]) std::swap(a, b);
        for (int diff = depth[a] - depth[b], j = 0; diff > 0; diff >>= 1, ++j)
            if (diff & 1) a = up[j][a];
        if (a == b) return a;
        for (int j = log - 1; j >= 0; --j) {
            if (up[j][a] != up[j][b]) {
                a = up[j][a];
                b = up[j][b];
            }
        }
        return up[0][a];
    };

    std::vector<std::vector<int>> members(k + 1);
    for (int v = 1; v <= n; ++v) members[state[v]].push_back(v);

    // Difference counts: after subtree summation, cover[v] > 0 iff the edge
    // (v, parent[v]) lies on a path between two vertices of the same state.
    // The LCA of a vertex set is the LCA of its min-tin and max-tin members.
    std::vector<int> cover(n + 1, 0);
    for (int c = 1; c <= k; ++c) {
        const auto& cities = members[c];
        if (cities.size() < 2) continue;
        const auto [lo, hi] = std::minmax_element(
            cities.begin(), cities.end(),
            [&](int a, int b) { return tin[a] < tin[b]; });
        const int top = lca(*lo, *hi);
        for (int v : cities) ++cover[v];
        cover[top] -= static_cast<int>(cities.size());
    }
    // Reverse preorder visits children before parents, so this yields subtree sums.
    for (int i = n - 1; i >= 1; --i) cover[parent[order[i]]] += cover[order[i]];

    // Components of the covered-edge forest. A component starts at the root
    // or directly below an uncovered edge.
    std::vector<int> comp(n + 1, 0);
    int components = 0;
    for (int v : order) {
        const bool startsNew = (v == kRoot) || cover[v] == 0;
        comp[v] = startsNew ? components++ : comp[parent[v]];
    }

    // True degrees in the contracted tree (each uncovered edge counted once).
    std::vector<int> degree(components, 0);
    for (int v : order) {
        if (v == kRoot || cover[v] != 0) continue;
        ++degree[comp[v]];
        ++degree[comp[parent[v]]];
    }
    const auto leaves = std::count(degree.begin(), degree.end(), 1);
    return static_cast<int>((leaves + 1) / 2);
}

}  // namespace

int main() {
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int n = 0;
    int k = 0;
    if (!(std::cin >> n >> k)) return 0;
    std::vector<std::pair<int, int>> edges(n > 0 ? n - 1 : 0);
    for (auto& [a, b] : edges) std::cin >> a >> b;
    std::vector<int> state(n + 1, 0);
    for (int i = 1; i <= n; ++i) std::cin >> state[i];

    std::cout << solveMergers(n, k, edges, state) << '\n';
    return 0;
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...