Submission #216687

# | Time | Username | Problem | Language | Result | Execution time | Memory
216687 | combi1k1 | Capital City (JOI20_capital_city) | C++14
100 / 100
398 ms | 46944 KiB
#include <algorithm>
#include <cassert>
#include <iostream>
#include <numeric>
#include <queue>
#include <stack>
#include <vector>

// Solution for "Capital City (JOI20_capital_city)".
//
// Outline of what the program does:
//   1. Adjacent towns of the same color are merged with a DSU; every DSU
//      component becomes one node of a compressed tree (`adj`).
//   2. The compressed tree is heavy-light decomposed so a root-ward path is a
//      small number of contiguous `pos` ranges.
//   3. Colors are processed in increasing `pos` of the LCA of their
//      components; a min segment tree over `pos` is used to mark "candidate"
//      colors (`ok`).
//   4. For each candidate color a BFS closure counts how many extra colors get
//      pulled in; the minimum over candidates is printed (presumably the
//      minimum number of merges asked for by the problem -- TODO confirm
//      against the statement).

namespace {

constexpr int N = 200000 + 5;
constexpr int INF = 1000000000;

// ---- DSU over the original towns ------------------------------------------
int dsu_par[N];
int dsu_size[N];

// Makes towns 1..n singleton sets.
void init(int n) {
    std::iota(dsu_par + 1, dsu_par + 1 + n, 1);
    std::fill(dsu_size + 1, dsu_size + 1 + n, 1);
}

// Returns the representative of x with path compression (recursion depth is
// O(log n) thanks to union by size).
int lead(int x) {
    return dsu_par[x] == x ? x : dsu_par[x] = lead(dsu_par[x]);
}

// Union by size; returns false when x and y were already in the same set.
bool join(int x, int y) {
    x = lead(x);
    y = lead(y);
    if (x == y) return false;
    if (dsu_size[x] < dsu_size[y]) std::swap(x, y);
    dsu_par[y] = x;
    dsu_size[x] += dsu_size[y];
    return true;
}

// ---- Input and compressed tree --------------------------------------------
int color[N];                         // color of each town (and of each DSU representative)
int edge_u[N], edge_v[N];             // original edges 1..n-1
std::vector<int> comps_of_color[N];   // DSU representatives per color, later sorted by pos
std::vector<int> adj[N];              // compressed tree over DSU representatives

int sub_size[N];    // subtree size in the compressed tree; sub_size[0] == 0 is a sentinel
int parent[N];      // parent in the compressed tree, 0 for the root
int chain_head[N];  // topmost node of the heavy chain containing the node
int pos[N];         // 1-based preorder index assigned by decompose()
int tot = 0;        // number of nodes numbered by decompose()

// Fills parent[] and sub_size[] for the tree rooted at `root`.
// Iterative so a path-shaped tree with ~2e5 nodes cannot exhaust the call stack.
void compute_sizes(int root) {
    std::vector<int> order;
    std::vector<int> todo{root};
    parent[root] = 0;
    while (!todo.empty()) {
        const int u = todo.back();
        todo.pop_back();
        order.push_back(u);
        sub_size[u] = 1;
        for (int v : adj[u]) {
            if (v != parent[u]) {
                parent[v] = u;
                todo.push_back(v);
            }
        }
    }
    // Every node is appended after its parent, so sweeping in reverse adds
    // each finished subtree into its parent.
    for (auto it = order.rbegin(); it != order.rend(); ++it) {
        if (*it != root) sub_size[parent[*it]] += sub_size[*it];
    }
}

// Heavy-light decomposition. Assigns pos[] in preorder, visiting the heavy
// child (first child, in adjacency order, with strictly largest subtree)
// first and then the light children in adjacency order; fills chain_head[].
// Iterative for the same stack-depth reason as compute_sizes().
void decompose(int root) {
    std::vector<int> todo{root};
    chain_head[root] = root;
    while (!todo.empty()) {
        const int u = todo.back();
        todo.pop_back();
        pos[u] = ++tot;

        int heavy = 0;
        for (int v : adj[u]) {
            if (v != parent[u] && sub_size[heavy] < sub_size[v]) heavy = v;
        }
        // LIFO: push light children in reverse adjacency order and the heavy
        // child last, so the heavy subtree is numbered first and the light
        // children follow in adjacency order -- same as the recursive form.
        for (auto it = adj[u].rbegin(); it != adj[u].rend(); ++it) {
            const int v = *it;
            if (v != parent[u] && v != heavy) {
                chain_head[v] = v;
                todo.push_back(v);
            }
        }
        if (heavy != 0) {
            chain_head[heavy] = chain_head[u];
            todo.push_back(heavy);
        }
    }
}

// True iff u is an ancestor of v (or u == v) in the compressed tree.
bool is_anc(int u, int v) {
    return pos[u] <= pos[v] && pos[u] + sub_size[u] >= pos[v] + sub_size[v];
}

// Lowest common ancestor via the heavy chains.
int lca(int u, int v) {
    while (!is_anc(chain_head[u], v)) u = parent[chain_head[u]];
    while (!is_anc(chain_head[v], u)) v = parent[chain_head[v]];
    if (pos[u] > pos[v]) std::swap(u, v);
    assert(is_anc(u, v));
    return u;
}

// ---- Bottom-up min segment tree over positions 1..tot ---------------------
int seg[N << 1];  // leaf for position i lives at seg[i + tot - 1]

// Sets position i to value and refreshes its ancestors.
void seg_update(int i, int value) {
    for (seg[i += tot - 1] = value; i > 1; i >>= 1) seg[i >> 1] = std::min(seg[i], seg[i ^ 1]);
}

// Minimum over positions [l, r] (inclusive); INF for an empty range.
int seg_min(int l, int r) {
    int res = INF;
    for (l += tot - 1, r += tot; l < r; l >>= 1, r >>= 1) {
        if (l & 1) res = std::min(res, seg[l++]);
        if (r & 1) res = std::min(res, seg[--r]);
    }
    return res;
}

int top_of[N];   // LCA of all components of a (non-empty) color
bool ok[N];      // color passed the segment-tree test in phase 3
bool have[N];    // node already queued in the current BFS closure
bool need[N];    // color already included in the current BFS closure

}  // namespace

int main() {
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int n = 0;
    int k = 0;
    std::cin >> n >> k;
    for (int i = 1; i < n; ++i) std::cin >> edge_u[i] >> edge_v[i];
    for (int i = 1; i <= n; ++i) std::cin >> color[i];

    // 1. Merge same-colored neighbours, then link the resulting components.
    init(n);
    for (int i = 1; i < n; ++i) {
        if (color[edge_u[i]] == color[edge_v[i]]) join(edge_u[i], edge_v[i]);
    }
    for (int i = 1; i < n; ++i) {
        if (color[edge_u[i]] != color[edge_v[i]]) {
            const int x = lead(edge_u[i]);
            const int y = lead(edge_v[i]);
            adj[x].push_back(y);
            adj[y].push_back(x);
        }
    }
    for (int i = 1; i <= n; ++i) {
        if (dsu_par[i] == i) comps_of_color[color[i]].push_back(i);
    }

    // 2. Root the compressed tree and decompose it.
    const int root = lead(1);
    compute_sizes(root);
    decompose(root);
    for (int i = 1; i <= tot; ++i) seg_update(i, INF);

    // 3. Order the colors that actually occur by the preorder position of the
    //    LCA of their components. Empty colors are skipped: they have no
    //    components, so [0]/back() would be undefined behaviour.
    std::vector<int> color_order;
    for (int i = 1; i <= k; ++i) {
        std::vector<int>& comps = comps_of_color[i];
        if (comps.empty()) continue;
        std::sort(comps.begin(), comps.end(), [](int x, int y) { return pos[x] < pos[y]; });
        // In preorder, the LCA of the first and last node is the LCA of all.
        top_of[i] = lca(comps.front(), comps.back());
        color_order.push_back(i);
    }
    std::sort(color_order.begin(), color_order.end(),
              [](int x, int y) { return pos[top_of[x]] < pos[top_of[y]]; });

    // For each color: P = min(pos[R], stored values on every path from a
    // component up to R). Store P on all its components; the color is a
    // candidate iff P equals the pos of its first component in preorder.
    for (int i : color_order) {
        const int top = top_of[i];
        int P = pos[top];
        for (int x : comps_of_color[i]) {
            while (!is_anc(chain_head[x], top)) {
                P = std::min(seg_min(pos[chain_head[x]], pos[x]), P);
                x = parent[chain_head[x]];
            }
            P = std::min(seg_min(pos[top], pos[x]), P);
        }
        for (int x : comps_of_color[i]) seg_update(pos[x], P);
        if (P == pos[comps_of_color[i][0]]) ok[i] = true;
    }

    // 4. For every candidate color, grow a closure upward: each visited node
    //    (other than the color's first component) pulls in its parent, and
    //    each newly touched color pulls in all of its components. `cur`
    //    counts the extra colors; a node whose stored value differs from the
    //    anchor's pos aborts the candidate (cur = INF).
    int ans = INF;
    for (int i = 1; i <= k; ++i) {
        if (!ok[i]) continue;
        const int anchor = comps_of_color[i][0];
        const int target = pos[anchor];

        need[i] = true;
        std::queue<int> qu;
        std::stack<int> touched;  // everything marked, for cleanup below
        for (int x : comps_of_color[i]) {
            qu.push(x);
            have[x] = true;
        }

        int cur = 0;
        while (!qu.empty()) {
            int u = qu.front();
            qu.pop();
            touched.push(u);
            if (seg_min(pos[u], pos[u]) != target) {
                cur = INF;
                // Move the remaining queued nodes to `touched` so their marks are reset.
                while (!qu.empty()) {
                    touched.push(qu.front());
                    qu.pop();
                }
                break;
            }
            if (u != anchor) u = parent[u];
            if (!have[u]) {
                have[u] = true;
                qu.push(u);
                if (!need[color[u]]) {
                    need[color[u]] = true;
                    ++cur;
                    for (int x : comps_of_color[color[u]]) {
                        if (!have[x]) {
                            have[x] = true;
                            qu.push(x);
                        }
                    }
                }
            }
        }

        // Reset the marks so the next candidate starts clean.
        while (!touched.empty()) {
            const int u = touched.top();
            touched.pop();
            have[u] = false;
            need[color[u]] = false;
        }
        ans = std::min(ans, cur);
    }

    std::cout << ans << '\n';
    return 0;
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...