Submission #1270443

# | Time | Username | Problem | Language | Result | Execution time | Memory
1270443 | _callmelucian | LOSTIKS (INOI20_lostiks) | C++17
100 / 100
1025 ms | 298580 KiB
// Solution for LOSTIKS (INOI20_lostiks).
//
// Input: n S T, then n - 1 edges "a b c"; c != 0 marks edge (a, b) as a locked
// door whose key lies at vertex c. The tables below hold at most 20 doors
// (dp / preDist / preMask sizes) -- NOTE(review): confirm against the limits.
//
// Approach (as implemented):
//   * heavy-light decomposition of the tree rooted at 1; lca / dist use chain
//     jumping, and getMask(u, v) returns the bitmask of doors on the u-v path
//     through an OR sparse table over DFS positions;
//   * bitmask DP: dp[mask][j] is the shortest walk that has unlocked exactly
//     the doors in `mask` and stands at endpoint j of the door unlocked last
//     (j < counter: first endpoint of door j; j >= counter: second endpoint of
//     door j - counter). A transition walks to the key of a new door, then to
//     one endpoint of that door, crossing only doors already in `mask`.
#include <bits/stdc++.h>
using namespace std;

using ll = long long;
using ld = long double;
using pl = pair<ll, ll>;
using pii = pair<int, int>;
using tpl = tuple<int, int, int>;

#define all(a) a.begin(), a.end()
#define filter(a) a.erase(unique(all(a)), a.end())

const int mn = 1e6 + 6;

// spt[i][s]: OR of door bits stored at DFS positions [i, i + 2^s).
// A door's bit is stored at the position of its deeper endpoint.
int spt[mn][21], num[mn], chain[mn], depth[mn], par[mn], sz[mn], timeDfs;
int dp[1 << 20][40], preDist[20][40], preMask[20][40];
vector<int> adj[mn];

namespace preprocess {
// Computes subtree sizes sz[] of the tree hanging from u (parent p) without
// recursion, so very deep trees cannot overflow the call stack.
// Also fills par[] for that subtree. Returns sz[u].
int szDfs(int u, int p) {
    vector<int> order, stk{u};
    par[u] = p;
    while (!stk.empty()) {
        int x = stk.back();
        stk.pop_back();
        order.push_back(x);
        sz[x] = 1;
        for (int y : adj[x])
            if (y != par[x]) par[y] = x, stk.push_back(y);
    }
    // Reverse preorder visits every child before its parent.
    for (size_t i = order.size(); i-- > 1;) sz[par[order[i]]] += sz[order[i]];
    return sz[u];
}

// Heavy-light decomposition starting at u (parent p, depth d). toP tells
// whether u continues its parent's chain. It assigns chain heads, parents,
// depths and preorder numbers num[]. The heaviest child is numbered
// immediately after its parent, so every chain occupies a contiguous range.
// Iterative replacement of the original recursion; it produces the same order.
void dfs(int u, int p, int d, bool toP) {
    if (u == 1) szDfs(u, p);
    chain[u] = (toP ? chain[p] : u);
    par[u] = p;
    depth[u] = d;
    vector<int> stk{u};
    while (!stk.empty()) {
        int x = stk.back();
        stk.pop_back();
        num[x] = ++timeDfs;

        // heavy child = first child with the largest subtree
        int big = 0;
        for (int y : adj[x])
            if (y != par[x] && (big == 0 || sz[y] > sz[big])) big = y;
        if (big == 0) continue; // leaf

        // Light children are pushed in reverse adjacency order and the heavy
        // child last, so the heavy child is popped (numbered) next and the
        // light ones are visited afterwards in adjacency order.
        for (auto it = adj[x].rbegin(); it != adj[x].rend(); ++it) {
            int y = *it;
            if (y == par[x]) continue;
            par[y] = x;
            depth[y] = depth[x] + 1;
            chain[y] = (y == big ? chain[x] : y);
            if (y != big) stk.push_back(y);
        }
        stk.push_back(big);
    }
}
} // namespace preprocess

// OR of spt over DFS positions [l, r]; 0 for an empty range.
int querySparse(int l, int r) {
    if (l > r) return 0;
    int p = 31 - __builtin_clz(r - l + 1);
    return spt[l][p] | spt[r - (1 << p) + 1][p];
}

// Lowest common ancestor via HLD chain jumping. When both chain heads have
// parents at equal depth, neither chain holds the LCA, so both may jump.
int lca(int a, int b) {
    while (chain[a] != chain[b]) {
        int ap = par[chain[a]], bp = par[chain[b]];
        if (depth[ap] == depth[bp]) a = ap, b = bp;
        else if (depth[ap] > depth[bp]) a = ap;
        else b = bp;
    }
    return (depth[a] < depth[b] ? a : b);
}

// Bitmask of the locked doors on the tree path between a and b.
int getMask(int a, int b) {
    int ans = 0;
    // OR in the doors from u up to (and including the edge above) its chain
    // head, then move u to the parent of that head.
    auto lift = [&](int& u) {
        ans |= querySparse(num[chain[u]], num[u]);
        u = par[chain[u]];
    };
    while (chain[a] != chain[b]) {
        int ap = par[chain[a]], bp = par[chain[b]];
        if (depth[ap] == depth[bp]) lift(a), lift(b);
        else if (depth[ap] > depth[bp]) lift(a);
        else lift(b);
    }
    if (depth[a] > depth[b]) swap(a, b);
    // same chain now: doors are stored below a, at positions num[a]+1..num[b]
    return ans | querySparse(num[a] + 1, num[b]);
}

// True iff every bit of `sub` is also set in `mask`.
bool isSub(int sub, int mask) { return (sub & mask) == sub; }

// True iff the u-v path crosses only doors contained in `mask`.
bool check(int u, int v, int mask) { return isSub(getMask(u, v), mask); }

// Number of edges on the tree path between u and v.
int dist(int u, int v) { return depth[u] + depth[v] - 2 * depth[lca(u, v)]; }

int main() {
    ios::sync_with_stdio(0);
    cin.tie(0);

    /// read input and preprocess
    int n, S, T, counter = 0;
    cin >> n >> S >> T;
    vector<tpl> locked;
    for (int i = 1; i < n; i++) {
        int a, b, c;
        cin >> a >> b >> c;
        adj[a].push_back(b);
        adj[b].push_back(a);
        if (c) {
            locked.emplace_back(a, b, c);
            counter++;
        }
    }
    preprocess::dfs(1, 0, 1, 0);

    /// build sparse table: door i sits at the DFS position of its deeper endpoint
    for (int i = 0; i < counter; i++) {
        int a = get<0>(locked[i]), b = get<1>(locked[i]);
        if (depth[a] < depth[b]) swap(a, b);
        spt[num[a]][0] = 1 << i;
    }
    for (int s = 1; (1 << s) <= n; s++) {
        int p = s - 1;
        for (int i = 1; i + (1 << s) - 1 <= n; i++)
            spt[i][s] = spt[i][p] | spt[i + (1 << p)][p];
    }

    /// endpoint j (< counter) = first vertex of door j, j + counter = second vertex
    vector<int> endpoint(2 * counter);
    for (int j = 0; j < counter; j++) {
        endpoint[j] = get<0>(locked[j]);
        endpoint[j + counter] = get<1>(locked[j]);
    }

    /// pre-calculations: key of door i <-> every door endpoint
    for (int i = 0; i < counter; i++) {
        int c = get<2>(locked[i]);
        for (int j = 0; j < 2 * counter; j++) {
            preDist[i][j] = dist(c, endpoint[j]);
            preMask[i][j] = getMask(c, endpoint[j]);
        }
    }
    // Endpoint -> T data only depends on the endpoint; computing it here keeps
    // the HLD walks out of the DP loop.
    vector<int> endDistT(2 * counter), endMaskT(2 * counter);
    for (int j = 0; j < 2 * counter; j++) {
        endDistT[j] = dist(endpoint[j], T);
        endMaskT[j] = getMask(endpoint[j], T);
    }

    /// prepare DP base-cases
    if (check(S, T, 0)) return cout << dist(S, T) << "\n", 0;
    for (int mask = 0; mask < (1 << counter); mask++)
        fill(dp[mask], dp[mask] + 2 * counter, INT_MAX);
    for (int i = 0; i < counter; i++) {
        int c = get<2>(locked[i]);
        if (!check(S, c, 0)) continue;
        // from S to key c, then to either endpoint of door i, crossing no door
        int mask = 1 << i, toKey = dist(S, c);
        for (int j : {i, i + counter})
            if (preMask[i][j] == 0) dp[mask][j] = toKey + preDist[i][j];
    }

    /// run bitmask DP; transitions only add bits, so increasing mask order works
    int ans = INT_MAX;
    for (int mask = 1; mask < (1 << counter); mask++) {
        for (int last = 0; last < counter; last++) {
            if (!(mask >> last & 1)) continue;
            for (int from : {last, last + counter}) {
                int here = dp[mask][from];
                if (here == INT_MAX) continue;
                // finish: walk from this endpoint to T through unlocked doors only
                if (isSub(endMaskT[from], mask)) ans = min(ans, here + endDistT[from]);
                for (int nxt = 0; nxt < counter; nxt++) {
                    if (mask >> nxt & 1) continue;
                    // walk to the key of door nxt ...
                    if (!isSub(preMask[nxt][from], mask)) continue;
                    int cur = here + preDist[nxt][from];
                    int nxtMask = mask | (1 << nxt);
                    // ... then to either endpoint of door nxt
                    for (int to : {nxt, nxt + counter})
                        if (isSub(preMask[nxt][to], mask))
                            dp[nxtMask][to] = min(dp[nxtMask][to], cur + preDist[nxt][to]);
                }
            }
        }
    }
    cout << (ans == INT_MAX ? -1 : ans) << "\n";
    return 0;
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...