Submission #1289568

# | Time | Username | Problem | Language | Result | Execution time | Memory
1289568 | | shidou26 | LOSTIKS (INOI20_lostiks) | C++20 | 100 / 100 | 1611 ms | 467736 KiB
// #pragma GCC optimize("O3,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")

#include <bits/stdc++.h>
using namespace std;

#ifdef KURUMI
    #include "algo/debug.h"
#endif

// Competitive-programming shorthand macros.
// NOTE: `endl` is redefined to a plain newline character, so `<< endl` does
// not flush the stream (and any `std::endl` spelling would no longer compile).
#define endl '\n'
#define fi first
#define se second
// Container size as a signed int. The argument is not parenthesized, so pass
// a simple expression (a name or member access).
#define sz(v) (int)v.size()
#define all(v) v.begin(), v.end()
// Drops adjacent duplicates in place; sort the container first to deduplicate fully.
#define filter(v) v.resize(unique(all(v)) - v.begin())
// Debug helper: expands to a stream fragment printing "[x = <value of x>]".
#define dbg(x) "[" #x << " = " << x << "]" 

// Global 64-bit Mersenne Twister, seeded from the steady clock at start-up.
mt19937_64 rng(chrono::steady_clock::now().time_since_epoch().count());

// Uniform random integer in the closed range [l, r]; the result has the type of `r`.
template<typename T1, typename T2> T2 rand(T1 l, T2 r) {
    uniform_int_distribution<T2> pick(l, r);
    return pick(rng);
}

// Skewed random integer in [l, r]:
//   seed == 0 -> one uniform draw,
//   seed  > 0 -> the maximum of seed + 1 uniform draws,
//   seed  < 0 -> the minimum of -seed + 1 uniform draws.
template<typename T1, typename T2> T2 wrand(T1 l, T2 r, int seed) {
    T2 best = rand(l, r);
    for(; seed > 0; --seed) best = max(best, rand(l, r));
    for(; seed < 0; ++seed) best = min(best, rand(l, r));
    return best;
}

// Raises `a` to `b` when `b` is strictly larger; reports whether `a` changed.
template<typename T> bool maximize(T &a, T b) {
    if(!(a < b)) return false;
    a = b;
    return true;
}
// Lowers `a` to `b` when `b` is strictly smaller; reports whether `a` changed.
template<typename T> bool minimize(T &a, T b) {
    if(!(a > b)) return false;
    a = b;
    return true;
}

// Short aliases for frequently used numeric and tuple types.
using ll = long long;
using ld = long double;
using pii = pair<int, int>;
using pli = pair<ll, int>;
using pll = pair<ll, ll>;
using tp3 = tuple<int, int, int>;

// Reads the next integer from stdin via getchar().
// Skips any non-digit characters before the number; a '-' immediately before
// the first digit makes the result negative. The character that terminates
// the number is consumed. Returns 0 if end of input is reached before any
// digit. The character is kept in an `int` so that EOF is distinguishable;
// storing it in a `char` made the skip loop spin forever at end of input.
int read_int(){
    int ch = getchar();
    bool negative = false;
    while (ch != EOF && (ch < '0' || ch > '9')) {
        negative = (ch == '-');
        ch = getchar();
    }
    if (ch == EOF) return 0;

    int result = 0;
    while (ch >= '0' && ch <= '9') {
        result = result * 10 + (ch - '0');
        ch = getchar();
    }
    return negative ? -result : result;
}

// Capacity of per-vertex tables; vertex ids are used directly as indices.
const int N = 1e6 + 3;
// Maximum number of locked edges supported: the DP table has 1 << M masks.
const int M = 20;

// n: number of vertices (n - 1 edges are read), m: number of locked edges,
// s: start vertex, t: target vertex.
int n, m, s, t;
// useful: endpoints of locked edges (the DP positions).
// only: endpoints plus lock vertices (every vertex used in DP transitions).
vector<int> useful, only;
// edge[i] = (u, v, lock) for the i-th locked edge. The DP walks to `lock`
// before crossing edge i, so `lock` acts as the location of that edge's key.
vector<tp3> edge;
// Adjacency lists of the whole tree (locked and unlocked edges).
vector<int> adj[N];

// Global mode switch for the DSU. While false, root() uses path compression
// and nothing is logged. Once set to true (process() does this after the
// unlocked components are built), path compression is disabled and every
// label written by unite() is logged so rollback() can undo it.
bool rolled = false;
// Disjoint-set union (union by size) with optional rollback.
//   lab[u] < 0  : u is a root and -lab[u] is the size of its set.
//   lab[u] >= 0 : lab[u] is the parent of u.
struct DisjointSet {
    // lab: parent/size labels. save: checkpoints (history lengths) from persist().
    vector<int> lab, save;
    // Undo log: (reference to a lab slot, value it held before the write).
    // The references point into `lab`, so `lab` must not be resized while the
    // log is in use.
    vector<pair<int&, int>> history;

    DisjointSet () {}
    // Creates n + 3 singleton sets (indices 0 .. n + 2).
    DisjointSet (int n) : lab(n + 3, -1) {}

    // Returns the representative of u. Path compression is applied only when
    // `rolled` is false, because compressed writes are not logged and could
    // not be undone.
    int root(int u) {
        if(lab[u] < 0) return u;
        return (rolled ? root(lab[u]) : lab[u] = root(lab[u]));
    }

    // Joins the sets of u and v, attaching the smaller set under the larger.
    // Returns false if they were already in the same set. In rollback mode
    // the previous values of both touched labels are logged first.
    bool unite(int u, int v) {
        u = root(u); v = root(v);
        if(u == v) return false;
        if(-lab[u] < -lab[v]) swap(u, v);
        
        if(rolled) {
            history.emplace_back(lab[u], lab[u]);
            history.emplace_back(lab[v], lab[v]);
        }

        lab[u] += lab[v];
        lab[v] = u;
        return true;
    }

    // True if u and v are currently in the same set.
    bool same(int u, int v) {
        return root(u) == root(v);
    }

    // Records a checkpoint at the current length of the undo log.
    void persist() {
        save.push_back(sz(history));
    }

    // Restores every label written since the most recent checkpoint, newest
    // first. The checkpoint itself is kept, so repeated calls roll back to the
    // same state. Requires at least one prior persist() (save.back() on an
    // empty vector is undefined).
    void rollback() {
        int t = save.back(); 
        while(sz(history) > t) {
            history.back().fi = history.back().se;
            history.pop_back();
        }
    }
} dsu;

// Reads one instance: n, the start vertex s, the target vertex t, then n - 1
// tree edges given as "u v lock". Every edge is added to the adjacency lists.
// An edge with lock == 0 is open and is merged into the DSU immediately.
// Otherwise it becomes locked edge number m, and its endpoints (plus the lock
// vertex) are collected for later index compression.
void input() {
    n = read_int();
    // m is not read from the input; it is counted from the locked edges below.
    s = read_int();
    t = read_int();
    dsu = DisjointSet(n);

    for(int edgesLeft = n - 1; edgesLeft > 0; --edgesLeft) {
        const int a = read_int();
        const int b = read_int();
        const int key = read_int();

        adj[a].push_back(b);
        adj[b].push_back(a);

        if(key == 0) {
            dsu.unite(a, b);
            continue;
        }

        ++m;
        edge.emplace_back(a, b, key);
        useful.push_back(a);
        useful.push_back(b);
        only.push_back(a);
        only.push_back(b);
        only.push_back(key);
    }
}

// Number of sparse-table levels. 2^LOG must exceed the Euler tour length
// (< 2 * N) so that the largest query level, floor(log2(length)), exists.
const int LOG = 21;
// "Unreachable" sentinel; equals the byte pattern 0x3f that memset writes into dp.
const int INF = 0x3f3f3f3f;

// timer: Euler-tour position counter. answer: best total distance found (INF = none).
int timer = 0, answer = INF;
// tin[u]: first Euler-tour position of u. h[u]: depth of u below root 1.
// id[u]: index of u in `useful`. oid[u]: index of u in `only`.
// predist[i][j]: tree distance between only[i] and only[j] (`only` holds at
// most 3 * M distinct vertices).
int tin[N], h[N], id[N], oid[N], predist[100][100];
// spt[i][j]: vertex with the smallest tin among Euler positions [i, i + 2^j).
// dp[mask][x]: shortest walk found so far that has opened exactly the locked
// edges in `mask` and stands at vertex useful[x].
int spt[2 * N][LOG], dp[1 << M][2 * M + 3];

// Combine step of the LCA sparse table: of two vertices, keeps the one that
// was entered earlier in the Euler tour (the smaller tin). Ties go to y.
int merge(int x, int y) {
    if(tin[y] <= tin[x]) return y;
    return x;
}

// Builds the Euler tour used for LCA queries, starting at u, whose parent is p
// (pass -1 for the root).
// Each vertex is written to spt[.][0] when it is first entered, and again
// after returning from each of its children. tin[v] is its first position,
// and h[v] is its depth, counting from h[u].
// The visiting order is exactly that of a recursive DFS over adj[x] in order.
// An explicit stack is used instead of recursion because a path-shaped tree
// of ~1e6 vertices would otherwise overflow the call stack.
void prepare(int u, int p) {
    struct Frame {
        int vertex, parent;
        size_t next;   // index of the next neighbour of `vertex` to examine
    };
    vector<Frame> frames;

    tin[u] = ++timer;
    spt[timer][0] = u;
    frames.push_back({u, p, 0});

    while(!frames.empty()) {
        Frame &top = frames.back();
        if(top.next == adj[top.vertex].size()) {
            frames.pop_back();
            // Returning to the parent: it appears in the tour once more.
            if(!frames.empty()) spt[++timer][0] = frames.back().vertex;
            continue;
        }

        const int v = adj[top.vertex][top.next++];
        if(v == top.parent) continue;

        h[v] = h[top.vertex] + 1;
        tin[v] = ++timer;
        spt[timer][0] = v;
        // The new frame is built before push_back; `top` is not used afterwards
        // (push_back may reallocate and invalidate it).
        frames.push_back({v, top.vertex, 0});
    }
}

int lca(int u, int v) {
    int l = tin[u], r = tin[v];
    if(l > r) swap(l, r);

    int b = 31 - __builtin_clz(r - l + 1);
    return merge(spt[l][b], spt[r - (1 << b) + 1][b]);
}

// Number of edges on the tree path between u and v.
int distance(int u, int v) {
    const int anc = lca(u, v);
    return (h[u] - h[anc]) + (h[v] - h[anc]);
}

void process() {
    prepare(1, -1);
    for(int j = 1; j < LOG; j++) {
        for(int i = 1; i + (1 << j) - 1 <= timer; i++) {
            spt[i][j] = merge(spt[i][j - 1], spt[i + (1 << (j - 1))][j - 1]);
        }
    }

    if(dsu.same(s, t)) return cout << distance(s, t) << endl, void();

    sort(all(useful)); filter(useful);
    int k = sz(useful);
    for(int i = 0; i < k; i++) id[useful[i]] = i;

    sort(all(only)); filter(only);
    for(int i = 0; i < sz(only); i++) oid[only[i]] = i;
    for(int i = 0; i < sz(only); i++) {
        for(int j = 0; j < sz(only); j++) {
            predist[i][j] = distance(only[i], only[j]);
        }
    }

    memset(dp, 0x3f, sizeof(dp));
    for(int i = 0; i < m; i++) {
        int u, v, lock; tie(u, v, lock) = edge[i];
        if(dsu.same(s, lock) && dsu.same(lock, u)) dp[1 << i][id[u]] = distance(s, lock) + distance(lock, u);//, cout << dbg(s) << dbg(lock) << dbg(u) << endl;
        if(dsu.same(s, lock) && dsu.same(lock, v)) dp[1 << i][id[v]] = distance(s, lock) + distance(lock, v);//, cout << dbg(s) << dbg(lock) << dbg(v) << endl;
    }

    rolled = true; dsu.persist();
    for(int mask = 0; mask < (1 << m); mask++) {
        vector<int> one, zero;
        for(int i = 0; i < m; i++) {
            if(mask >> i & 1) {
                one.push_back(i);

                int u, v; tie(u, v, ignore) = edge[i];
                dsu.unite(u, v);
            }
            else zero.push_back(i);
        }

        for(int i : one) {
            int u, v; tie(u, v, ignore) = edge[i];
            int ru = dsu.root(u), rv = dsu.root(v);

            for(int j : zero) {
                int nu, nv, lock; tie(nu, nv, lock) = edge[j];
                int rnu = dsu.root(nu), rnv = dsu.root(nv), rlock = dsu.root(lock);

                if(ru == rlock) {
                    if(rlock == rnu) minimize(dp[mask ^ (1 << j)][id[nu]], dp[mask][id[u]] + predist[oid[u]][oid[lock]] + predist[oid[nu]][oid[lock]]);
                    if(rlock == rnv) minimize(dp[mask ^ (1 << j)][id[nv]], dp[mask][id[u]] + predist[oid[u]][oid[lock]] + predist[oid[nv]][oid[lock]]);
                }

                if(rv == rlock) {
                    if(rlock == rnu) minimize(dp[mask ^ (1 << j)][id[nu]], dp[mask][id[v]] + predist[oid[v]][oid[lock]] + predist[oid[nu]][oid[lock]]);
                    if(rlock == rnv) minimize(dp[mask ^ (1 << j)][id[nv]], dp[mask][id[v]] + predist[oid[v]][oid[lock]] + predist[oid[nv]][oid[lock]]);
                }
            }
        }

        for(int i = 0; i < k; i++) {
            if(dsu.same(useful[i], t)) minimize(answer, dp[mask][i] + distance(useful[i], t));
        }

        dsu.rollback();
    }

    cout << (answer == INF ? -1 : answer) << endl;
}

// Entry point. If TREEMAZE.inp exists, stdin/stdout are redirected to
// TREEMAZE.inp / TREEMAZE.out (local testing). Then a single test case is
// read and solved. The remaining diagnostics go to stderr only.
int main() {
    ios_base::sync_with_stdio(0);
    cin.tie(0);

    #define task "TREEMAZE"
    // The probe handle is closed rather than leaked. The freopen results are
    // checked: ignoring them triggered -Wunused-result and would hide a failed
    // redirection.
    if(FILE *probe = fopen(task".inp", "r")) {
        fclose(probe);
        if(!freopen(task".inp", "r", stdin) || !freopen(task".out", "w", stdout)) {
            cerr << "failed to redirect I/O for " task << endl;
            return 1;
        }
    }
    
    int testcase = 1; // cin >> testcase;    
    for(int i = 1; i <= testcase; i++) {
        input();
        process();
    }

    cerr << "Saa, watashtachi no deeto hajimemashou" << endl;
    cerr << "Atarashii kiseki wo koko kara hajimeru shining place nee mou ichido kimi to..." << endl;
    
    cerr << "Time elapsed: " << (1000.0 * clock() / CLOCKS_PER_SEC) << "ms.\n";

    return 0;
}

Compilation message (stderr)

Main.cpp: In function 'int main()':
Main.cpp:259:16: warning: ignoring return value of 'FILE* freopen(const char*, const char*, FILE*)' declared with attribute 'warn_unused_result' [-Wunused-result]
  259 |         freopen(task".inp", "r", stdin);
      |         ~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~
Main.cpp:260:16: warning: ignoring return value of 'FILE* freopen(const char*, const char*, FILE*)' declared with attribute 'warn_unused_result' [-Wunused-result]
  260 |         freopen(task".out", "w", stdout);
      |         ~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~~
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...