제출 #1270443

# | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리
1270443_callmelucianLOSTIKS (INOI20_lostiks)C++17
100 / 100
1025 ms | 298580 KiB
#include <bits/stdc++.h>
using namespace std;

// Short aliases for commonly used numeric, pair and tuple types.
using ll = long long;
using ld = long double;
using pl = pair<ll,ll>;
using pii = pair<int,int>;
// main stores each locked edge as a tpl: (endpoint, endpoint, third input value of the edge).
using tpl = tuple<int,int,int>;

// all(a): expands to the begin/end iterator pair of container a.
#define all(a) a.begin(), a.end()
// filter(a): removes consecutive duplicate elements of a (unique followed by erase).
#define filter(a) a.erase(unique(all(a)), a.end())

// Capacity of the per-vertex arrays. main roots the tree at vertex 1 with parent 0,
// so index 0 is read as "parent of the root" (its entries keep their zero initial value).
const int mn = 1e6 + 6;
// spt[i][s]: bitwise OR of the level-0 entries at positions i .. i + 2^s - 1 (filled in main);
//            the level-0 entry at num[x] carries the bit of the locked edge whose deeper endpoint is x.
// num[u]:    visit index assigned by preprocess::dfs (counted by timeDfs, starting at 1).
// chain[u]:  first vertex of the chain that u belongs to (set by preprocess::dfs).
// depth[u], par[u]: depth and parent recorded by preprocess::dfs.
// sz[u]:     subtree size computed by preprocess::szDfs.
int spt[mn][21], num[mn], chain[mn], depth[mn], par[mn], sz[mn], timeDfs;
// Endpoint slots: slot j < counter is the first stored endpoint of locked edge j,
// slot j + counter is its second stored endpoint (counter = number of locked edges in main).
// dp[mask][slot]:  smallest cost found so far for having collected exactly the edges in mask
//                  and standing at that endpoint slot; INT_MAX marks "not reached".
// preDist[i][slot]: distance from the third stored value (used as a vertex) of locked edge i to that slot.
// preMask[i][slot]: bitmask of locked edges on the path between those two vertices.
int dp[1 << 20][40], preDist[20][40], preMask[20][40];
// Adjacency lists of the input tree, indexed by vertex.
vector<int> adj[mn];

namespace preprocess {
    /// Computes sz[] for every vertex in the subtree of `u` (entered from `p`)
    /// and returns the size of that subtree.
    int szDfs (int u, int p) {
        int total = 1;
        for (int v : adj[u]) {
            if (v == p) continue;
            total += szDfs(v, u);
        }
        return sz[u] = total;
    }

    /// Decomposes the tree into chains while recording par[], depth[], num[] and chain[].
    /// `u` is the current vertex, `p` its parent, `d` its depth; `toP` tells whether `u`
    /// continues the chain of `p` (true) or starts a new chain of its own (false).
    /// Subtree sizes are computed once, when the traversal starts at vertex 1.
    void dfs (int u, int p, int d, bool toP) {
        if (u == 1) szDfs(u, p);

        par[u] = p;
        depth[u] = d;
        num[u] = ++timeDfs;
        if (toP) chain[u] = chain[p];
        else chain[u] = u;

        // Choose the child with the largest subtree; on ties the earliest one in adj[u] wins.
        int heavy = -1;
        for (int v : adj[u]) {
            if (v == p) continue;
            if (heavy == -1 || sz[heavy] < sz[v]) heavy = v;
        }
        if (heavy == -1) return; // leaf

        // The heavy child is visited first so that every chain occupies consecutive num[] values.
        dfs(heavy, u, d + 1, true);
        for (int v : adj[u]) {
            if (v == p || v == heavy) continue;
            dfs(v, u, d + 1, false);
        }
    }
}

/// Returns the bitwise OR of the level-0 sparse-table entries at positions l..r (inclusive).
/// An empty range (l > r) yields 0.
int querySparse (int l, int r) {
    const int len = r - l + 1;
    if (len <= 0) return 0;

    // Largest power of two not exceeding the range length; the two windows of that
    // size anchored at both ends cover the range (overlap is harmless for OR).
    const int level = 31 - __builtin_clz(len);
    const int rightStart = r - (1 << level) + 1;
    return spt[l][level] | spt[rightStart][level];
}

int lca (int a, int b) {
    while (chain[a] != chain[b]) {
        int ap = par[chain[a]], bp = par[chain[b]];
        if (depth[ap] == depth[bp]) a = ap, b = bp;
        else if (depth[ap] > depth[bp]) a = ap;
        else if (depth[bp] > depth[ap]) b = bp;
    }
    return (depth[a] < depth[b] ? a : b);
}

int getMask (int a, int b) {
    int ans = 0;
    function<void(int&)> lift = [&] (int &u) {
        ans |= querySparse(num[chain[u]], num[u]);
        u = par[chain[u]];
    };

    while (chain[a] != chain[b]) {
        int ap = par[chain[a]], bp = par[chain[b]];
        if (depth[ap] == depth[bp]) lift(a), lift(b);
        else if (depth[ap] > depth[bp]) lift(a);
        else if (depth[bp] > depth[ap]) lift(b);
    }
    if (depth[a] > depth[b]) swap(a, b);
    return ans | querySparse(num[a] + 1, num[b]);
}

/// Returns true when every bit set in `sub` is also set in `mask`.
bool isSub (int sub, int mask) {
    const int missing = sub & ~mask; // bits required by sub that mask lacks
    return missing == 0;
}

/// Returns true when every locked edge on the path between `u` and `v` is contained in `mask`.
bool check (int u, int v, int mask) {
    const int required = getMask(u, v);
    return isSub(required, mask);
}

/// Returns the number of edges on the tree path between `u` and `v`.
int dist (int u, int v) {
    const int top = lca(u, v);
    return (depth[u] - depth[top]) + (depth[v] - depth[top]);
}

int main()
{
    ios::sync_with_stdio(0);
    cin.tie(0);

    /// read input and preprocess
    int n, S, T, counter = 0;
    cin >> n >> S >> T;

    vector<tpl> locked;
    for (int i = 1; i < n; i++) {
        int a, b, c; cin >> a >> b >> c;
        adj[a].push_back(b);
        adj[b].push_back(a);
        if (c) {
            locked.emplace_back(a, b, c);
            counter++;
        }
    }
    preprocess::dfs(1, 0, 1, 0);

    /// build sparse table
    for (int i = 0; i < counter; i++) {
        int a, b, c; tie(a, b, c) = locked[i];
        if (depth[a] < depth[b]) swap(a, b);
        spt[num[a]][0] = 1 << i;
    }
    for (int s = 1; (1 << s) <= n; s++) {
        int p = s - 1;
        for (int i = 1; i + (1 << s) - 1 <= n; i++)
            spt[i][s] = spt[i][p] | spt[i + (1 << p)][p];
    }

    /// pre-calculations
    for (int i = 0; i < counter; i++) {
        int c = get<2>(locked[i]);
        for (int j = 0; j < counter; j++) {
            int u, v, w; tie(u, v, w) = locked[j];
            preDist[i][j] = dist(c, u), preDist[i][j + counter] = dist(c, v);
            preMask[i][j] = getMask(c, u), preMask[i][j + counter] = getMask(c, v);
        }
    }

    /// prepare DP base-cases
    if (check(S, T, 0)) return cout << dist(S, T) << "\n", 0;

    for (int mask = 0; mask < (1 << counter); mask++)
        fill(dp[mask], dp[mask] + 2 * counter, INT_MAX);
    for (int i = 0; i < counter; i++) {
        int a, b, c; tie(a, b, c) = locked[i];
        if (!check(S, c, 0)) continue;

        // from S to c to either a or b
        int mask = 1 << i;
        if (check(c, a, 0)) dp[mask][i] = dist(S, c) + preDist[i][i];
        if (check(c, b, 0)) dp[mask][i + counter] = dist(S, c) + preDist[i][i + counter];
    }

    /// run bitmask DP
    int full = (1 << counter) - 1, ans = INT_MAX;
    for (int mask = 1; mask < (1 << counter); mask++) {
        vector<int> inMask, outMask;
        for (int i = 0; i < counter; i++) {
            if (mask >> i & 1) inMask.push_back(i);
            else outMask.push_back(i);
        }

        for (int last : inMask) {
            int a, b, c; tie(a, b, c) = locked[last];
            for (int nxt : outMask) {
                int u, v, w; tie(u, v, w) = locked[nxt];
                int nxtMask = mask | (1 << nxt);
                if (dp[mask][last] != INT_MAX && isSub(preMask[nxt][last], mask)) { // from a to w to either u or v
                    int cur = dp[mask][last] + preDist[nxt][last];
                    if (isSub(preMask[nxt][nxt], mask)) // w to u
                        dp[nxtMask][nxt] = min(dp[nxtMask][nxt], cur + preDist[nxt][nxt]);
                    if (isSub(preMask[nxt][nxt + counter], mask)) // w to v
                        dp[nxtMask][nxt + counter] = min(dp[nxtMask][nxt + counter], cur + preDist[nxt][nxt + counter]);
                }
                if (dp[mask][last + counter] != INT_MAX && isSub(preMask[nxt][last + counter], mask)) { // from b to w to either u or v
                    int cur = dp[mask][last + counter] + preDist[nxt][last + counter];
                    if (isSub(preMask[nxt][nxt], mask)) // w to u
                        dp[nxtMask][nxt] = min(dp[nxtMask][nxt], cur + preDist[nxt][nxt]);
                    if (isSub(preMask[nxt][nxt + counter], mask)) // w to v
                        dp[nxtMask][nxt + counter] = min(dp[nxtMask][nxt + counter], cur + preDist[nxt][nxt + counter]);
                }
            }

            // update answer
            if (dp[mask][last] != INT_MAX && check(a, T, mask))
                ans = min(ans, dp[mask][last] + dist(a, T));
            if (dp[mask][last + counter] != INT_MAX && check(b, T, mask))
                ans = min(ans, dp[mask][last + counter] + dist(b, T));
            
            // if (dp[mask][last] != INT_MAX) cout << "dist " << a << " " << bitset<3>(mask) << " " << dp[mask][last] << "\n";
            // if (dp[mask][last + counter] != INT_MAX) cout << "dist " << b << " " << bitset<3>(mask) << " " << dp[mask][last + counter] << "\n";
        }
    }
    cout << (ans == INT_MAX ? -1 : ans) << "\n";


    return 0;
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...