// Submission #1173439
// Username:       nguyenkhangninh99
// Problem:        Factories (JOI14_factories)
// Language:       C++20
// Result:         100 / 100
// Execution time: 2021 ms
// Memory:         168332 KiB

#include<bits/stdc++.h>
#include "factories.h"

using namespace std;

using ll = long long;

// Capacity for vertex ids; Init shifts the 0-based ids it receives to 1-based.
constexpr int maxn = 500005;

// Original tree: g[u] lists (neighbour, edge length) pairs.
vector<pair<int, int>> g[maxn];
// Virtual tree of the current query: adj[u] lists (neighbour, path length) pairs.
vector<pair<int, ll>> adj[maxn];

// h[u]: depth of u counted in edges from the root.
int h[maxn];
// up[u][j]: the 2^j-th ancestor of u (0 once the jump passes the root).
int up[maxn][22];
// tin[u]: DFS entry time; out[u]: largest entry time inside u's subtree.
int tin[maxn];
int out[maxn];
// Running DFS clock used to assign tin/out.
int timeDfs;

// depth[u]: weighted distance from the root to u.
ll depth[maxn];
// Best answer found so far for the query being processed.
ll ans;
// f[u][0] / f[u][1]: distance from u to the nearest set-1 / set-2 vertex in
// u's virtual subtree (sentinel 1e16 when there is none).
ll f[maxn][2];

// Membership flags of the two vertex sets of the current query.
bool is1[maxn];
bool is2[maxn];

/*
 * Walks the tree hanging below u and, for every vertex v discovered, fills:
 *   h[v], depth[v] - depth in edges / weighted distance from the root,
 *   up[v][j]       - binary-lifting table (up[v][0] is v's parent),
 *   tin[v], out[v] - entry time and the largest entry time inside v's subtree,
 *                    so w lies in subtree(v) iff tin[v] <= tin[w] <= out[v].
 * u's own h/depth/up entries must already be set (all zero for the root).
 *
 * Uses an explicit stack instead of recursion: a path-shaped tree with
 * 5e5 vertices would otherwise recurse 5e5 frames deep and can overflow a
 * default-sized call stack.
 */
void dfs(int u){
    // Each frame holds a vertex and the index of its next adjacency entry.
    vector<pair<int, size_t>> st;
    tin[u] = ++timeDfs;
    st.emplace_back(u, 0);

    while(!st.empty()){
        const int x = st.back().first;
        size_t &edge = st.back().second;

        if(edge == g[x].size()){
            // Whole subtree of x visited: record its last entry time.
            out[x] = timeDfs;
            st.pop_back();
            continue;
        }

        const auto [v, w] = g[x][edge++];
        if(v == up[x][0]) continue; // edge back to the parent

        h[v] = h[x] + 1;
        depth[v] = depth[x] + w;
        up[v][0] = x;
        for(int j = 1; j <= 21; j++) up[v][j] = up[up[v][j - 1]][j - 1];

        tin[v] = ++timeDfs;
        // May reallocate `st`; `edge` is not used again after this point.
        st.emplace_back(v, 0);
    }
}
 
int lca(int u, int v){
	if(h[u] != h[v]){
		if(h[u] < h[v]) swap(u, v);
		
		int k = h[u] - h[v];
		
		for(int j = 0; (1 << j) <= k; j++){
			if((k >> j) & 1){
				u = up[u][j];
			}
		}
	}
	
	if(u == v)  return u;
	
	int k = __lg(h[u]);
	
	for(int j = k; j >= 0; j--) {
		if(up[u][j] != up[v][j]){
			u = up[u][j];
			v = up[v][j];
		}
	}
	
	return up[u][0];
}

// Weighted length of the tree path u -> v: the two root distances minus
// twice the root distance of the vertices' lowest common ancestor.
ll dist(int u, int v){
    const ll meet = depth[lca(u, v)];
    return (depth[u] - meet) + (depth[v] - meet);
}

void dp(int u, int pre){
    for(auto [v, w]: adj[u]){
        if(v == pre) continue;
        dp(v, u);
        f[u][0] = min(f[u][0], f[v][0] + w);
        f[u][1] = min(f[u][1], f[v][1] + w);
    }

    if(is1[u]){
        f[u][0] = 0;
        if(is2[u]){
            f[u][1] = 0;
            ans = 0;
        }
        else ans = min(ans, f[u][1]);
    }
    else if(is2[u]){
        f[u][1] = 0;
        ans = min(ans, f[u][0]);
    }
    else ans = min(ans, f[u][0] + f[u][1]);
}
/*
 * One-time setup. Receives the N_ - 1 weighted edges (A[e], B[e], D[e]) with
 * 0-based endpoints and stores them 1-based in g. It then roots the tree at
 * vertex 1 via dfs and puts both f entries of every vertex at the 1e16
 * sentinel that Query's DP expects.
 */
void Init(int N_, int A[], int B[], int D[]){
    for(int e = 0; e + 1 < N_; e++){
        const int a = A[e] + 1;
        const int b = B[e] + 1;
        const int len = D[e];
        g[a].push_back({b, len});
        g[b].push_back({a, len});
    }

    dfs(1);

    for(int x = 1; x <= N_; x++){
        f[x][1] = 1e16;
        f[x][0] = f[x][1];
    }
}
long long Query(int S, int X[], int T, int Y[]){
    int s = S, t = T;
    vector<int> comp1(s), comp2(t);

    vector<int> vert;

    for(int i = 0; i < s; i++) is1[++X[i]] = 1, vert.push_back(X[i]);
    for(int i = 0; i < t; i++) is2[++Y[i]] = 1, vert.push_back(Y[i]);

    sort(vert.begin(), vert.end(), [&](int c, int d) {return tin[c] < tin[d];});

    int k = vert.size();
    for(int i = 0; i < k - 1; i++) vert.push_back(lca(vert[i], vert[i + 1]));

    sort(vert.begin(), vert.end(), [&](int c, int d) {return tin[c] < tin[d];});
    vert.erase(unique(vert.begin(), vert.end()), vert.end());

    stack<int> st;
    for(int u: vert) {
        while(st.size() && !(tin[st.top()] <= tin[u] && tin[u] <= out[st.top()])) st.pop();
        if(st.size()){
            int v = st.top();
            ll w = dist(u, v);
            adj[u].push_back({v, w});
            adj[v].push_back({u, w});
        }
        st.push(u);
    }

    ans = 1e16;
    dp(vert[0], -1);

    for(int u: vert){
        is1[u] = is2[u] = 0;
        f[u][0] = f[u][1] = 1e16;
        adj[u].clear();
    }

    return ans;
}
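// Results table (columns: # | Verdict | Execution time | Memory | Grader output)
// Fetching results...
// Results table (columns: # | Verdict | Execution time | Memory | Grader output)
// Fetching results...
// Results table (columns: # | Verdict | Execution time | Memory | Grader output)
// Fetching results...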