제출 #1280739

# | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리
1280739 | | MasterMoon | LOSTIKS (INOI20_lostiks) | C++17 | 100 / 100 | 1940 ms | 455676 KiB
#include <bits/stdc++.h>
using namespace std;
// Entry-point alias: `__Master_Moon__ { ... }` further down is `int main() { ... }`.
// NOTE(review): identifiers beginning with a double underscore are reserved
// for the implementation; a plain `int main()` would be safer.
#define __Master_Moon__ int main()
// Competitive-programming shorthands (integer types, pair members, containers).
// NOTE(review): sq(x) lacks outer parentheses, so e.g. 1/sq(x) misparses.
#define ll long long
#define el "\n"
#define fi first
#define sq(x) (x)*(x)
#define se second
#define pub push_back
#define puf push_front
#define pii pair <int, int>
#define pll pair <long long, long long>
#define piii pair <int, pair <int, int>>
#define iiii pair <int, pair <int, pair <int, int>>>
#define plll pair <long long, pair <long long, long long>>
// FOR: inclusive ascending a..b; FO: inclusive descending a..b; REP: 0..n-1.
#define FOR(i, a, b) for(int i = (a);i <=(b);i++)
#define FO(i, a, b) for(int i = (a);i >= (b);i--)
#define REP(i, n) for(int i = 0;i < (n);i++)
// Capacity of vertex-indexed arrays (presumably n <= 1e6 - confirm against
// the problem limits).
long const maxn = 1e6 + 5;
// Number of binary-lifting levels: jumps of 2^0 .. 2^20.
long const lg = 21;
// Size of dp's mask dimension: room for every mask below 1 << 20, i.e. at
// most 20 door edges.
long const MASK = (1 << 20) + 2;
// dp[i][mask][side]: shortest walk that has opened exactly the doors in
// `mask`, opened door i last, and stands at endpoint vertical[i][side]
// (-1 = unreachable, set by memset in solve()).
// p[v][k]: 2^k-th ancestor of v (0 above the root); OR[v][k]: OR of the door
// bits on the edges of that jump. it: number of door edges read so far.
int dp[21][MASK][2],p[maxn][lg],OR[maxn][lg],it;
// Over compressed vertices 1..m: dist[x][y] = tree-path length,
// res[x][y] = OR of door bits on that path. id: vertex -> compressed index
// (used only as a "seen" flag while reading input); a: index -> vertex.
int dist[100][100],res[100][100],id[maxn],a[100];
// h: depth (dfs(1, 0) gives the root h = 1); val[i]: vertex holding the key
// of door i; vertical[i][0/1]: endpoints of door i; lca: written by LCA().
int s,t,n,h[maxn],val[21],vertical[21][2],lca;
// ans: best walk length found (INT_MAX = none yet); m: compressed-vertex count.
int ans = INT_MAX,m = 0;
// Adjacency list: (neighbor, door bit of the edge, or 0 for a plain edge).
vector<pii> g[maxn];
void dfs(int u,int par)
{
    h[u] = h[par] + 1;
    for(pii x : g[u])
    {
        if(x.fi != par)
        {
            p[x.fi][0] = u;
            OR[x.fi][0] = x.se;
            dfs(x.fi,u);
        }
    }
}
void init()
{
    FOR(i,1,20)
    {
        FOR(j,1,n)
        {
            p[j][i] = p[p[j][i-1]][i-1];
            OR[j][i] = OR[j][i-1] | OR[p[j][i-1]][i-1];
        }
    }
}
// Climbs u and v to their lowest common ancestor with the lifting tables.
// Returns the OR of the door bits of every edge on the u-v path (0 when the
// path crosses no door) and stores the common ancestor in the global `lca`.
int LCA(int u,int v)
{
    if(h[v] > h[u]) swap(u,v);          // u is now the deeper endpoint
    int gap = h[u] - h[v];
    int doors = 0;
    // Lift u to v's depth, collecting door bits along the way.
    for(int k = 20; k >= 0; --k)
    {
        if(gap >= (1<<k))
        {
            gap -= 1<<k;
            doors |= OR[u][k];
            u = p[u][k];
        }
    }
    if(u != v)
    {
        // Lift both while their 2^k ancestors still differ.
        for(int k = 20; k >= 0; --k)
        {
            if(p[u][k] == p[v][k]) continue;
            doors |= OR[u][k] | OR[v][k];
            u = p[u][k];
            v = p[v][k];
        }
        // u and v are now children of the LCA; add their parent edges.
        doors |= OR[u][0] | OR[v][0];
        u = p[u][0];
    }
    lca = u;
    return doors;
}
void compress()
{
    FOR(i,1,m) id[a[i]] = i;
    FOR(i,1,m)
    {
        FOR(j,1,m)
        {
            res[i][j] = LCA(a[i],a[j]);
            dist[i][j] = h[a[i]] + h[a[j]] - 2*h[lca];  
        }
    }
}
int build(int start,int to,int between,int key,int mask)
{
    int tmp = 0;
    start = id[start];
    to = id[to];
    between = id[between];
    if((res[start][between] | mask) != mask) return -1;
    tmp += dist[start][between];
    mask |= (1<<key); 
    if((res[to][between] | mask) != mask) return -1;
    tmp += dist[to][between];
    return tmp;
}
// Reads the tree and prints the length of the shortest walk from s to t in
// which a door edge may be crossed only after its key vertex has been
// visited. Prints -1 if t cannot be reached.
//
// State: dp[i][mask][side] = shortest walk that has opened exactly the doors
// in `mask`, opened door i last, and now stands at vertical[i][side]
// (-1 = unreachable). Opening door i means walking to its key val[i] through
// already-open doors, then on to one of its endpoints.
void solve()
{
    cin >> n >> s >> t;
    memset(dp,-1,sizeof dp);
    // Records x as an interesting vertex for compress(). While reading,
    // id[] is only a "seen" flag; compress() overwrites it with indices.
    // Marking immediately avoids duplicates when a key lies on an endpoint
    // of its own door.
    auto addPoint = [](int x)
    {
        if(id[x] == 0) {m++; a[m] = x; id[x] = 1;}
    };
    FOR(i,1,n-1)
    {
        int u,v,w,bit = 0;
        cin >> u >> v >> w;
        if(w)
        {
            // Door edge whose key lies at vertex w; it gets bit index `it`.
            bit = (1<<it);
            val[it] = w;
            vertical[it][0] = u;
            vertical[it][1] = v;
            addPoint(u);
            addPoint(v);
            addPoint(w);
            it++;
        }
        g[u].pub({v,bit});
        g[v].pub({u,bit});
    }
    dfs(1,0);
    init();
    // No door on the direct s-t path: just walk it.
    if(LCA(s,t) == 0)
    {
        cout << h[s] + h[t] - 2*h[lca];
        return;
    }
    addPoint(s);
    addPoint(t);
    compress();
    // Keeps the smaller of two non-negative costs, treating -1 as "unset".
    auto relax = [](int &cell, int cand)
    {
        if(cell == -1 || cand < cell) cell = cand;
    };
    // Base case: door i is the first door opened, starting from s.
    REP(i,it)
    {
        dp[i][1<<i][0] = build(s,vertical[i][0],val[i],i,0);
        dp[i][1<<i][1] = build(s,vertical[i][1],val[i],i,0);
    }
    // Masks are visited in increasing numeric order, so mask^(1<<i) is final
    // before `mask` is processed. Single-door masks must be visited as well.
    // They have no transitions, but they are answer candidates; skipping
    // them would never report a walk that needs exactly one door.
    FOR(mask,1,(1<<it)-1)
    {
        REP(i,it)
        {
            if(!(mask&(1<<i))) continue;
            int prev = mask^(1<<i);
            // Extend every state whose last door j is in prev by opening door i.
            REP(j,it)
            {
                if(!(prev&(1<<j))) continue;
                REP(from,2)
                {
                    int base = dp[j][prev][from];
                    if(base == -1) continue;
                    REP(side,2)
                    {
                        int leg = build(vertical[j][from],vertical[i][side],val[i],i,prev);
                        if(leg != -1) relax(dp[i][mask][side],base + leg);
                    }
                }
            }
            // Finish: from the door endpoint, walk to t through open doors only.
            REP(side,2)
            {
                int at = id[vertical[i][side]];
                if(dp[i][mask][side] != -1 && (res[at][id[t]] | mask) == mask)
                    ans = min(ans,dp[i][mask][side] + dist[at][id[t]]);
            }
        }
    }
    if(ans == INT_MAX) ans = -1;
    cout << ans;
}
__Master_Moon__
{
    ios_base::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);
    solve();
    return 0;
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...