Submission #898839

# | Time | Username | Problem | Language | Result | Execution time | Memory
898839 | juliany2 | Capital City (JOI20_capital_city) | C++17
41 / 100
3042 ms | 67980 KiB
#pragma GCC optimize("O3,unroll-loops")
#pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
#define all(x) (x).begin(), (x).end()

// Overview (JOI 2020 "Capital City"):
//   Build a directed graph on colours with an edge a -> b whenever a vertex of
//   colour b lies on the tree path from some vertex of colour a to start[a],
//   the LCA of all vertices of colour a. Kosaraju's algorithm finds its SCCs:
//     - pass 1 (dfs1) enumerates forward edges with heavy-light decomposition
//       plus a set of not-yet-consumed HLD positions;
//     - pass 2 (dfs2) enumerates reversed edges with a min segment tree over
//       Euler-tour indices.
//   For every SCC whose colours' vertices form one connected piece of the tree,
//   |SCC| - 1 is a candidate answer; the minimum (starting from k - 1) is printed.

// Bottom-up segment tree over leaves [0, sz); comb must be associative.
template <class T>
struct ST {
    static constexpr T ID = {(int) 1e9, 0};  // identity for min over array<int, 2>
    inline T comb(T a, T b) { return min(a, b); }

    int sz;
    vector<T> t;

    // Resize to _sz leaves, each holding val, and rebuild the internal nodes.
    // (Previously val was ignored and every node was set to ID.)
    void init(int _sz, T val = ID) {
        t.assign((sz = _sz) * 2, ID);
        for (int i = 0; i < sz; ++i) t[i + sz] = val;
        for (int i = sz - 1; i > 0; --i) t[i] = comb(t[i * 2], t[(i * 2) | 1]);
    }

    // Build from explicit leaf values.
    void init(vector<T> &v) {
        t.resize((sz = v.size()) * 2);
        for (int i = 0; i < sz; ++i) t[i + sz] = v[i];
        for (int i = sz - 1; i; --i) t[i] = comb(t[i * 2], t[(i * 2) | 1]);
    }

    // Point assignment: leaf i := x.
    void upd(int i, T x) {
        for (t[i += sz] = x; i > 1; i >>= 1) t[i >> 1] = comb(t[i], t[i ^ 1]);
    }

    // Combine of leaves in the closed range [l, r]; ID for an empty range.
    T query(int l, int r) {
        T ql = ID, qr = ID;
        for (l += sz, r += sz + 1; l < r; l >>= 1, r >>= 1) {
            if (l & 1) ql = comb(ql, t[l++]);
            if (r & 1) qr = comb(t[--r], qr);
        }
        return comb(ql, qr);
    }
};

// Disjoint-set union on 0..sz with union by size and path compression.
struct DSU {
    vector<int> e;  // e[x] < 0: x is a root and -e[x] is its set size; else parent
    DSU(int sz) { e = vector<int>(sz + 1, -1); }
    int get(int x) { return e[x] < 0 ? x : e[x] = get(e[x]); }
    bool same_set(int a, int b) { return get(a) == get(b); }
    int size(int x) { return -e[get(x)]; }
    bool unite(int x, int y) {
        x = get(x), y = get(y);
        if (x == y) return false;
        if (e[x] > e[y]) swap(x, y);
        e[x] += e[y];
        e[y] = x;
        return true;
    }
};

const int N = 2e5 + 7, L = 20;
int n, k;
vector<int> adj[N], col[N], topo, comp;
set<int> active;       // HLD positions not yet consumed by pass 1
ST<array<int, 2>> st;  // leaf tin[x] holds {dep[start[c[x]]], x} until x is used
int c[N], lift[N][L], dep[N], sz[N], head[N], pos[N], start[N], who[N], timer;
int tin[N], tout[N], reach[N];  // reach[h]: chain h is consumed from h down to this depth
bool vis[N];

// Euler tour (tin/tout), depths, subtree sizes and binary-lifting table.
// Also moves the largest child of v to adj[v][0] so dfs_hld treats it as heavy.
void dfs(int v = 1, int p = 0) {
    tin[v] = ++timer;
    sz[v] = 1;
    lift[v][0] = p;
    for (int i = 1; i < L; i++) lift[v][i] = lift[lift[v][i - 1]][i - 1];
    for (int &u : adj[v]) {
        if (u == p) continue;
        dep[u] = dep[v] + 1;
        dfs(u, v);
        sz[v] += sz[u];
        // Compare the child's size, not sz[v]. sz[v] always exceeds every child's
        // size, so that test swapped every child in and made the last one "heavy".
        // That breaks the O(log n) chains-per-path bound that query() relies on.
        if (adj[v][0] == p || sz[u] > sz[adj[v][0]]) swap(u, adj[v][0]);
    }
    tout[v] = timer;
}

// Assign HLD positions (pos) and chain heads; adj[v][0] continues v's chain.
void dfs_hld(int v = 1, int p = 0) {
    pos[v] = timer++;
    for (int u : adj[v]) {
        if (u != p) {
            head[u] = (u == adj[v][0] ? head[v] : u);
            dfs_hld(u, v);
        }
    }
}

// Lowest common ancestor via binary lifting.
int lca(int u, int v) {
    if (dep[u] > dep[v]) swap(u, v);
    for (int i = L - 1; ~i; --i)
        if (dep[v] - (1 << i) >= dep[u]) v = lift[v][i];
    if (u == v) return u;
    for (int i = L - 1; ~i; --i)
        if (lift[v][i] != lift[u][i]) v = lift[v][i], u = lift[u][i];
    return lift[u][0];
}

void dfs1(int a);

// Pass 1 helper: consume every still-active HLD position in [a, b] and run
// dfs1 from the colour found there if it is unvisited. The lower_bound is
// repeated after each recursive call because dfs1 may erase other positions.
void process(int a, int b) {
    for (auto it = active.lower_bound(a); it != active.end() && *it <= b;
         it = active.lower_bound(a)) {
        int x = *it;
        active.erase(it);
        if (!vis[who[x]]) dfs1(who[x]);
    }
}

// Pass 1 edges: visit the colours on the path from b up to its ancestor a.
void query(int a, int b) {
    for (; head[a] != head[b]; b = lift[head[b]][0]) {
        // Skip a chain whose prefix was already consumed at least this deep.
        if (dep[b] > reach[head[b]]) {
            process(pos[head[b]], pos[b]);
            // A recursive call may already have consumed this chain deeper.
            reach[head[b]] = max(reach[head[b]], dep[b]);
        }
    }
    process(pos[a], pos[b]);
}

// Kosaraju pass 1 on the colour graph; appends colours in finishing order.
void dfs1(int a) {
    vis[a] = 1;
    for (int v : col[a]) query(start[a], v);
    topo.push_back(a);
}

// Kosaraju pass 2 on the reversed colour graph; collects the SCC into comp.
// Original edge c[x] -> a exists when a vertex v of colour a lies on the path
// from x to start[c[x]]: x is in v's subtree and dep[start[c[x]]] <= dep[v].
// Each x is removed from the segment tree once it has been used.
void dfs2(int a) {
    vis[a] = 1;
    comp.push_back(a);
    for (int v : col[a]) {
        for (auto best = st.query(tin[v], tout[v]); best[0] <= dep[v];
             best = st.query(tin[v], tout[v])) {
            int x = best[1];
            st.upd(tin[x], {(int) 1e9, 0});
            if (!vis[c[x]]) dfs2(c[x]);
        }
    }
}

int main() {
    cin.tie(0)->sync_with_stdio(false);
    cin >> n >> k;
    for (int i = 1; i < n; i++) {
        int u, v;
        cin >> u >> v;
        adj[u].push_back(v);
        adj[v].push_back(u);
    }
    for (int i = 1; i <= n; i++) {
        cin >> c[i];
        col[c[i]].push_back(i);
    }

    timer = 0;
    dfs();

    // start[i] = LCA of all vertices of colour i (stays 0 for an unused colour).
    for (int i = 1; i <= k; i++) {
        if (col[i].empty()) continue;
        start[i] = col[i][0];
        for (int x : col[i]) start[i] = lca(start[i], x);
    }

    timer = 1;
    head[1] = 1;
    dfs_hld();
    for (int i = 1; i <= n; i++) who[pos[i]] = c[i];
    for (int i = 1; i <= n; i++) active.insert(active.end(), i);
    memset(reach, -1, sizeof(reach));

    for (int i = 1; i <= k; i++)
        if (!vis[i]) dfs1(i);
    reverse(all(topo));

    st.init(n + 1);
    memset(vis, 0, sizeof(vis));
    memset(head, 0, sizeof(head));  // head[] is reused below as a per-SCC vertex mark
    for (int i = 1; i <= n; i++) st.upd(tin[i], {dep[start[c[i]]], i});

    DSU dsu(n + 1);
    int ans = k - 1;
    for (int v : topo) {
        if (vis[v]) continue;
        comp.clear();
        dfs2(v);

        int cnt = 0, rep = 0;
        for (int x : comp) {
            for (int u : col[x]) {
                head[u] = v;
                cnt++;
                rep = u;
            }
        }
        for (int x : comp)
            for (int s : col[x])
                for (int t : adj[s])
                    if (head[t] == v) dsu.unite(s, t);

        // The SCC is a valid choice iff its vertices form one connected piece.
        if (cnt > 0 && dsu.size(rep) == cnt) ans = min(ans, (int) comp.size() - 1);
    }
    cout << ans << '\n';
    return 0;
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...