제출 #1269688

# | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리
1269688_callmelucianLOSTIKS (INOI20_lostiks)C++17
0 / 100
66 ms | 106824 KiB
#include <bits/stdc++.h>
using namespace std;

// ---- short type aliases ----
using ll  = long long;
using ld  = long double;
using pl  = pair<ll, ll>;
using pii = pair<int, int>;
using tpl = tuple<int, int, int>;

// ---- container helpers ----
#define all(a) a.begin(), a.end()
#define filter(a) a.erase(unique(all(a)), a.end())

// Size bound for every per-vertex array (vertex ids up to ~1e6).
constexpr int mn = 1e6 + 6;

// DP table: first index is a bitmask over gates (up to 20 gates), second a per-state slot.
int dp[1 << 20][40];

// Gate i joins vertices nodeA[i] and nodeB[i]; its key lies on vertex nodeLock[i].
int nodeA[mn];
int nodeB[mn];
int nodeLock[mn];

// 20 x 20 scratch tables of distances / gate masks between gate-related vertices.
int dstA[20][20];
int dstB[20][20];
int mskA[20][20];
int mskB[20][20];

// spt[pos][k]: OR of gate bits stored at DFS positions pos .. pos + 2^k - 1.
int spt[mn][20];

// Heavy-light decomposition data per vertex.
int chain[mn];   // head of the heavy chain containing the vertex
int sz[mn];      // subtree size
int num[mn];     // DFS (preorder) index
int depth[mn];   // depth, root has depth 1
int par[mn];     // parent, 0 for the root
int timeDfs;     // running counter used to assign num[]

// Adjacency lists of the tree.
vector<int> adj[mn];

int szDfs (int u, int p) {
    sz[u] = 1;
    for (int v : adj[u])
        if (v != p) sz[u] += szDfs(v, u);
    return sz[u];
}

void dfs (int u, int p, int d, bool toP) {
    if (u == 1) szDfs(u, p);
    chain[u] = (toP ? chain[p] : u), par[u] = p;
    depth[u] = d, num[u] = ++timeDfs;
    if (adj[u].size() == (u > 1)) return;

    // process biggest subtree
    int big = *max_element(all(adj[u]), [&] (int i, int j) {
        return (i == p || j == p ? i == p : sz[i] < sz[j]);
    });
    dfs(big, u, d + 1, 1);

    // process other subtrees
    for (int v : adj[u])
        if (v != p && v != big) dfs(v, u, d + 1, 0);
}

// Lowest common ancestor of a and b using the heavy-light chains.
// While the two vertices sit on different chains, jump whichever chain head
// hangs lower (both jumps when the heads' parents are equally deep).
int lca (int a, int b) {
    for (; chain[a] != chain[b]; ) {
        const int upA = par[chain[a]];
        const int upB = par[chain[b]];
        const bool moveA = depth[upA] >= depth[upB];
        const bool moveB = depth[upB] >= depth[upA];
        if (moveA) a = upA;
        if (moveB) b = upB;
    }
    // Same chain: the shallower vertex is the ancestor (ties resolve to b).
    return depth[b] <= depth[a] ? b : a;
}

// OR of spt level-0 values over positions u..v (inclusive) in O(1) using two
// overlapping power-of-two windows. An empty range (u > v) yields 0.
int query (int u, int v) {
    if (v < u) return 0;
    const int len = v - u + 1;
    const int lvl = 31 - __builtin_clz(len);   // floor(log2(len))
    const int left = spt[u][lvl];
    const int right = spt[v + 1 - (1 << lvl)][lvl];
    return left | right;
}

// Bitmask of all gates lying on the tree path between a and b.
// Gate bits are stored in spt at the DFS index of the deeper endpoint of the
// gated edge, i.e. position num[x] represents the edge x -> par[x].
int getMask (int a, int b) {
    int mask = 0;

    // Move u from its position up to the parent of its chain head, collecting
    // every edge crossed: u->par[u], ..., and the head's own edge head->par[head],
    // which lives at num[chain[u]] — so the range must start there, not one past it.
    auto lift = [&] (int &u) {
        mask |= query(num[chain[u]], num[u]);
        u = par[chain[u]];
    };

    while (chain[a] != chain[b]) {
        int ap = par[chain[a]], bp = par[chain[b]];
        if (depth[ap] == depth[bp]) lift(a), lift(b);
        else if (depth[ap] > depth[bp]) lift(a);
        else lift(b);
    }
    if (depth[a] > depth[b]) swap(a, b);
    // Same chain, a is the ancestor: edges strictly below a down to b occupy
    // the contiguous positions num[a]+1 .. num[b].
    return mask | query(num[a] + 1, num[b]);
}

// True when every gate on the u–v path has its bit set in `mask`,
// i.e. the path can be walked holding exactly the keys in `mask`.
bool check (int u, int v, int mask) {
    const int required = getMask(u, v);
    return (required & ~mask) == 0;
}

// True when every bit of `sub` is also set in `mask` (sub is a subset of mask).
bool checkFast (int sub, int mask) {
    const int missing = sub & ~mask;
    return missing == 0;
}

// Number of edges on the tree path between u and v.
int dist (int u, int v) {
    const int meet = lca(u, v);
    return (depth[u] - depth[meet]) + (depth[v] - depth[meet]);
}

int main()
{
    ios::sync_with_stdio(0);
    cin.tie(0);

    /// read input
    int n, S, T, counter = 0;
    cin >> n >> S >> T;

    vector<pii> locked;
    for (int i = 1; i < n; i++) {
        int a, b, c; cin >> a >> b >> c;
        adj[a].push_back(b);
        adj[b].push_back(a);
        if (c) {
            locked.emplace_back(a, b);
            nodeA[counter] = a, nodeB[counter] = b;
            nodeLock[counter++] = c;
        }
    }
    dfs(1, 0, 1, 0);

    /// build sparse table
    for (int i = 0; i < counter; i++) {
        int a, b; tie(a, b) = locked[i];
        if (depth[a] < depth[b]) swap(a, b);
        spt[num[a]][0] = (1 << i);
    }

    for (int s = 1; (1 << s) <= n; s++) {
        int p = s - 1;
        for (int i = 1; i + (1 << s) - 1 <= n; i++)
            spt[i][s] = spt[i][p] | spt[i + (1 << p)][p];
    }

    /// pre-calculations
    for (int i = 0; i < counter; i++) {
        for (int j = 0; j < counter; j++) {
            dstA[i][j] = dist(nodeLock[i], nodeA[j]);
            dstB[i][j] = dist(nodeLock[i], nodeB[j]);
            mskA[i][j] = getMask(nodeLock[i], nodeA[j]);
            mskB[i][j] = getMask(nodeLock[i], nodeB[j]);
        }
    }


    /// setup base cases for DP
    for (int mask = 0; mask < (1 << counter); mask++)
        for (int j = 0; j < 2 * counter; j++) dp[mask][j] = INT_MAX;
    for (int i = 0; i < counter; i++) {
        int mask = (1 << i);
        if (check(S, nodeLock[i], 0)) {
            if (checkFast(mskA[i][i], 0)) dp[mask][i] = dist(S, nodeLock[i]) + dstA[i][i];
            if (checkFast(mskB[i][i], 0)) dp[mask][i + counter] = dist(S, nodeLock[i]) + dstB[i][i];
        }
    }

    /// run DP
    int full = (1 << counter) - 1, ans = INT_MAX;
    for (int mask = 1; mask < (1 << counter); mask++) {
        for (int sub = mask; sub; sub -= sub & (-sub)) {
            int last = __builtin_ctz(sub);
            if (dp[mask][last] == INT_MAX && dp[mask][last + counter] == INT_MAX) continue;

            for (int sub2 = mask ^ full; sub2; sub2 -= sub2 & (-sub2)) {
                int nxt = __builtin_ctz(sub2), dA = dstA[nxt][nxt], dB = dstB[nxt][nxt];
                bool checkA = checkFast(mskA[nxt][nxt], mask), checkB = checkFast(mskB[nxt][nxt], mask);

                if (checkFast(mskA[nxt][last], mask) && dp[mask][last] != INT_MAX) {
                    int cur = dp[mask][last] + dstA[nxt][last];
                    if (checkA)
                        dp[mask | (1 << nxt)][nxt] = min(dp[mask | (1 << nxt)][nxt], cur + dA);
                    if (checkB)
                        dp[mask | (1 << nxt)][nxt + counter] = min(dp[mask | (1 << nxt)][nxt + counter], cur + dB);
                }

                if (checkFast(mskB[nxt][last], mask) && dp[mask][last + counter] != INT_MAX) {
                    int cur = dp[mask][last + counter] + dstB[nxt][last];
                    if (checkA)
                        dp[mask | (1 << nxt)][nxt] = min(dp[mask | (1 << nxt)][nxt], cur + dA);
                    if (checkB)
                        dp[mask | (1 << nxt)][nxt + counter] = min(dp[mask | (1 << nxt)][nxt + counter], cur + dB);
                }
            }

            if (check(nodeA[last], T, mask) && dp[mask][last] != INT_MAX)
                ans = min(ans, dp[mask][last] + dist(nodeA[last], T));
            if (check(nodeB[last], T, mask) && dp[mask][last + counter] != INT_MAX)
                ans = min(ans, dp[mask][last + counter] + dist(nodeB[last], T));
        }
    }
    cout << ans << "\n";

    return 0;
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...