Submission #1269680

# | Time | Username | Problem | Language | Result | Execution time | Memory
1269680_callmelucianLOSTIKS (INOI20_lostiks)C++17
0 / 100
65 ms | 37192 KiB
#include <bits/stdc++.h>
using namespace std;

using ll = long long;
using ld = long double;
using pl = pair<ll,ll>;
using pii = pair<int,int>;
using tpl = tuple<int,int,int>;

#define all(a) a.begin(), a.end()
#define filter(a) a.erase(unique(all(a)), a.end())

const int mn = 1e6 + 6;

// Column offset of the "standing on endpoint B" half of a dp row.
// The per-lock tables below are sized for at most 20 locked edges, so columns
// [0, 20) hold the A-side states and [20, 40) hold the B-side states.
const int sideB = 20;

// dp[mask][i]         : cheapest walk after opening exactly the locks in `mask`,
//                       the last one being lock i, standing on nodeA[i].
// dp[mask][i + sideB] : same, but standing on nodeB[i].
// nodeA / nodeB       : the two endpoints of locked edge i (in input order).
// nodeLock            : the vertex given as `c` for locked edge i (where its key is).
// dstA[i][j], mskA[i][j] : distance / lock mask of the path nodeLock[i] -> nodeA[j]
// dstB[i][j], mskB[i][j] : distance / lock mask of the path nodeLock[i] -> nodeB[j]
int dp[1 << 20][40], nodeA[mn], nodeB[mn], nodeLock[mn],
    dstA[20][20], dstB[20][20], mskA[20][20], mskB[20][20];

// spt   : sparse table of bitwise ORs over the dfs visit order; slot num[v]
//         carries the bit of the locked edge (v, par[v]) if that edge is locked.
// chain : head of the heavy chain containing each vertex.
// num   : visit index assigned by dfs (1-based).
int spt[mn][20], chain[mn], sz[mn], num[mn], depth[mn], par[mn], timeDfs;
vector<int> adj[mn];

/// Computes sz[u], the number of vertices in the subtree of u (p is u's parent).
int szDfs (int u, int p) {
    sz[u] = 1;
    for (int v : adj[u])
        if (v != p) sz[u] += szDfs(v, u);
    return sz[u];
}

/// Heavy-path decomposition rooted at vertex 1.
/// Fills chain[], par[], depth[] and num[]; the largest child subtree is
/// visited first, so every chain occupies consecutive num[] values.
/// `toP` tells whether u continues the chain of its parent p.
void dfs (int u, int p, int d, bool toP) {
    if (u == 1) szDfs(u, p);
    chain[u] = (toP ? chain[p] : u), par[u] = p;
    depth[u] = d, num[u] = ++timeDfs;

    // leaf: the root has no neighbours at all, any other vertex only its parent
    if (adj[u].size() == (u > 1)) return;

    // process biggest subtree (the parent is ranked below every child)
    int big = *max_element(all(adj[u]), [&] (int i, int j) {
        return (i == p || j == p ? i == p : sz[i] < sz[j]);
    });
    dfs(big, u, d + 1, 1);

    // process other subtrees, each of which starts a new chain
    for (int v : adj[u])
        if (v != p && v != big) dfs(v, u, d + 1, 0);
}

/// Lowest common ancestor of a and b, found by jumping over whole chains.
int lca (int a, int b) {
    while (chain[a] != chain[b]) {
        int ap = par[chain[a]], bp = par[chain[b]];
        if (depth[ap] == depth[bp]) a = ap, b = bp;
        else if (depth[ap] > depth[bp]) a = ap;
        else if (depth[bp] > depth[ap]) b = bp;
    }
    return (depth[a] < depth[b] ? a : b);
}

/// Bitwise OR of spt slots u..v (inclusive); 0 for an empty range.
int query (int u, int v) {
    if (u > v) return 0;
    int p = 31 - __builtin_clz(v - u + 1);
    return spt[u][p] | spt[v - (1 << p) + 1][p];
}

/// Returns the OR of the bits of every locked edge on the tree path a - b.
int getMask (int a, int b) {
    int mask = 0;
    function<void(int&)> lift = [&] (int &u) {
        // The range starts AT the chain head: the head's slot holds the edge to
        // its parent, which is exactly the edge crossed by the jump below.
        // (Starting at num[chain[u]] + 1 silently dropped a lock on that edge.)
        mask |= query(num[chain[u]], num[u]);
        u = par[chain[u]];
    };

    while (chain[a] != chain[b]) {
        int ap = par[chain[a]], bp = par[chain[b]];
        if (depth[ap] == depth[bp]) lift(a), lift(b);
        else if (depth[ap] > depth[bp]) lift(a);
        else if (depth[bp] > depth[ap]) lift(b);
    }

    // same chain now: take the edges strictly below the shallower vertex
    if (depth[a] > depth[b]) swap(a, b);
    return mask | query(num[a] + 1, num[b]);
}

/// True when every locked edge on the path u - v is contained in `mask`.
bool check (int u, int v, int mask) {
    int needMask = getMask(u, v);
    return (needMask & mask) == needMask;
}

/// True when `sub` is a subset of `mask`.
bool checkFast (int sub, int mask) {
    return (sub & mask) == sub;
}

/// Number of edges on the tree path u - v.
int dist (int u, int v) {
    return depth[u] + depth[v] - 2 * depth[lca(u, v)];
}

int main() {
    ios::sync_with_stdio(0);
    cin.tie(0);

    /// read input
    int n, S, T, counter = 0;
    cin >> n >> S >> T;

    vector<pii> locked;
    for (int i = 1; i < n; i++) {
        int a, b, c;
        cin >> a >> b >> c;
        adj[a].push_back(b);
        adj[b].push_back(a);
        if (c) {
            locked.emplace_back(a, b);
            nodeA[counter] = a, nodeB[counter] = b;
            nodeLock[counter++] = c;
        }
    }
    dfs(1, 0, 1, 0);

    /// build sparse table
    for (int i = 0; i < counter; i++) {
        int a, b;
        tie(a, b) = locked[i];
        // store the lock bit at the deeper endpoint of the edge
        if (depth[a] < depth[b]) swap(a, b);
        spt[num[a]][0] = (1 << i);
    }
    for (int s = 1; (1 << s) <= n; s++) {
        int p = s - 1;
        for (int i = 1; i + (1 << s) - 1 <= n; i++)
            spt[i][s] = spt[i][p] | spt[i + (1 << p)][p];
    }

    /// pre-calculations
    for (int i = 0; i < counter; i++) {
        for (int j = 0; j < counter; j++) {
            dstA[i][j] = dist(nodeLock[i], nodeA[j]);
            dstB[i][j] = dist(nodeLock[i], nodeB[j]);
            mskA[i][j] = getMask(nodeLock[i], nodeA[j]);
            mskB[i][j] = getMask(nodeLock[i], nodeB[j]);
        }
    }

    // Paths from each lock endpoint to T do not depend on the DP mask, so walk
    // the tree once per endpoint instead of once per (mask, lock) pair.
    vector<int> mskAT(counter), mskBT(counter), dstAT(counter), dstBT(counter);
    for (int i = 0; i < counter; i++) {
        mskAT[i] = getMask(nodeA[i], T), dstAT[i] = dist(nodeA[i], T);
        mskBT[i] = getMask(nodeB[i], T), dstBT[i] = dist(nodeB[i], T);
    }

    /// setup base cases for DP
    for (int mask = 0; mask < (1 << counter); mask++)
        for (int j = 0; j < counter; j++)
            dp[mask][j] = dp[mask][j + sideB] = INT_MAX;

    for (int i = 0; i < counter; i++) {
        int mask = (1 << i);
        // lock i can be opened first only if S reaches nodeLock[i] with no lock opened
        if (check(S, nodeLock[i], 0)) {
            int toKey = dist(S, nodeLock[i]);
            if (checkFast(mskA[i][i], 0)) dp[mask][i] = toKey + dstA[i][i];
            if (checkFast(mskB[i][i], 0)) dp[mask][i + sideB] = toKey + dstB[i][i];
        }
    }

    /// run DP
    int full = (1 << counter) - 1, ans = INT_MAX;

    // Route that opens nothing: the S - T path is unique, so if it is free of
    // locked edges its length is the answer's upper bound (and the only
    // candidate when there are no locked edges at all).
    if (check(S, T, 0)) ans = dist(S, T);

    for (int mask = 1; mask < (1 << counter); mask++) {
        for (int sub = mask; sub; sub -= sub & (-sub)) {
            int last = __builtin_ctz(sub);
            int curA = dp[mask][last], curB = dp[mask][last + sideB];
            if (curA == INT_MAX && curB == INT_MAX) continue;

            for (int sub2 = mask ^ full; sub2; sub2 -= sub2 & (-sub2)) {
                int nxt = __builtin_ctz(sub2), to = mask | (1 << nxt);
                int dA = dstA[nxt][nxt], dB = dstB[nxt][nxt];
                // which endpoint of lock nxt is reachable from nodeLock[nxt]
                bool checkA = checkFast(mskA[nxt][nxt], mask),
                     checkB = checkFast(mskB[nxt][nxt], mask);

                // leave from nodeA[last] towards nodeLock[nxt]
                if (checkFast(mskA[nxt][last], mask) && curA != INT_MAX) {
                    int cur = curA + dstA[nxt][last];
                    if (checkA) dp[to][nxt] = min(dp[to][nxt], cur + dA);
                    if (checkB) dp[to][nxt + sideB] = min(dp[to][nxt + sideB], cur + dB);
                }
                // leave from nodeB[last] towards nodeLock[nxt]
                if (checkFast(mskB[nxt][last], mask) && curB != INT_MAX) {
                    int cur = curB + dstB[nxt][last];
                    if (checkA) dp[to][nxt] = min(dp[to][nxt], cur + dA);
                    if (checkB) dp[to][nxt + sideB] = min(dp[to][nxt + sideB], cur + dB);
                }
            }

            // finish by walking to T from whichever endpoint we stand on
            if (checkFast(mskAT[last], mask) && curA != INT_MAX)
                ans = min(ans, curA + dstAT[last]);
            if (checkFast(mskBT[last], mask) && curB != INT_MAX)
                ans = min(ans, curB + dstBT[last]);
        }
    }

    // NOTE(review): -1 is the assumed "T is unreachable" output; confirm against
    // the problem statement. The previous code printed INT_MAX in that case.
    cout << (ans == INT_MAX ? -1 : ans) << "\n";

    return 0;
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...