#include<bits/stdc++.h>
#include "factories.h"
using namespace std;
#define int long long
// Because of the macro above, every `int` written in this file is really
// `long long` (presumably so that summed edge lengths cannot overflow);
// `signed` still names the built-in int where that type is required.
// Per-node arrays are indexed by 1-based node id, so they assume
// N <= 500000 — TODO confirm against the problem bounds.
int n,q,tin[500005], tout[500005], timer = -1, s, t, w[500005], dp[2][500005], col[500005];
vector<pair<int,int>> adj[500005], adj2[500005];
// Sparse table over the Euler tour built by dfs(). dfs() appends a node once
// on entry and once more after each of the N - 1 child returns, so the tour
// has 2N - 1 entries (~1e6 for N = 500000): the second dimension must be
// about twice the node bound, otherwise build() writes out of bounds.
// 20 levels cover every range length below 2^20. Entries are (tin, node);
// both fit in the built-in int, which (with a 4-byte int) keeps the table at
// ~160 MB instead of ~320 MB for long long pairs.
pair<signed, signed> up[20][2 * 500005];
vector<int> euler;
int res = 1e18;
// Comparator for sorting node ids by DFS entry time (tin), i.e. preorder.
bool cmp(int a, int b) {
    const int entryA = tin[a];
    const int entryB = tin[b];
    return entryB > entryA;
}
// Euler-tour DFS over `adj` starting at u, whose parent is `par` (-1 for the
// root). For every node x reached below u it sets w[x] = w[parent] + edge
// length (w[u] itself is left unchanged). It appends x to `euler` on entry and
// again after each child finishes. Every append is paired with ++timer, so when
// timer == euler.size() - 1 on entry (true for the first call: timer = -1,
// euler empty), tin[x] is x's first position in `euler` and tout[x] is the last
// position belonging to x's subtree.
// Uses an explicit stack instead of recursion: a path-shaped tree with
// hundreds of thousands of nodes would otherwise need that many nested calls
// and can overflow the call stack. Visit order is the same as a recursive DFS.
void dfs(int u, int par) {
    struct Frame {
        int node, parent;
        size_t next;  // index into adj[node] of the next edge to examine
    };
    vector<Frame> frames;
    tin[u] = ++timer;
    euler.push_back(u);
    frames.push_back(Frame{u, par, 0});
    while (!frames.empty()) {
        Frame &top = frames.back();
        if (top.next < adj[top.node].size()) {
            pair<int, int> edge = adj[top.node][top.next++];
            if (edge.first == top.parent) continue;
            int child = edge.first;
            w[child] = w[top.node] + edge.second;
            tin[child] = ++timer;
            euler.push_back(child);
            // The Frame temporary is built before push_back may reallocate;
            // `top` must not be used after this line.
            frames.push_back(Frame{child, top.node, 0});
        } else {
            tout[top.node] = timer;
            frames.pop_back();  // `top` is dangling from here on
            if (!frames.empty()) {
                // Back in the parent after a child returns: record it again.
                ++timer;
                euler.push_back(frames.back().node);
            }
        }
    }
}
// Fills the sparse table `up` from the Euler tour. Level 0 holds the
// (tin, node) pair for tour position i. Level k holds, at index j, the
// minimum of the 2^k consecutive level-0 entries starting at j (computed from
// two halves of level k-1).
void build() {
    const size_t len = euler.size();
    for (size_t i = 0; i < len; i++) {
        const int node = euler[i];
        up[0][i] = {tin[node], node};
    }
    for (int k = 1; k < 20; k++) {
        const size_t width = static_cast<size_t>(1) << k;  // 2^k entries per window
        const size_t half = width / 2;
        for (size_t j = 0; j + width <= len; j++)
            up[k][j] = min(up[k - 1][j], up[k - 1][j + half]);
    }
}
// Range-minimum query on the sparse table for tour positions [u, v]
// (requires u <= v). Returns the (tin, node) entry with the smallest tin in
// that range. When u and v are the tin values of two nodes, .second is their
// lowest common ancestor (Euler-tour RMQ).
pair<int, int> get(int u, int v) {
    const int k = __lg(v - u + 1);              // largest power of two <= range length
    const auto &left = up[k][u];                // window [u, u + 2^k)
    const auto &right = up[k][v - (1 << k) + 1]; // window ending at v; the two overlap
    return right < left ? right : left;
}
void dfs2(int u, int par) {
dp[0][u] = dp[1][u] = 1e18;
for (pair<int, int> node: adj2[u]) {
if(node.first == par) continue;
dfs2(node.first, u);
dp[0][u] = min(dp[0][u], dp[0][node.first] + node.second);
dp[1][u] = min(dp[1][u], dp[1][node.first] + node.second);
if(col[node.first] == 1) dp[0][u] = min(dp[0][u], w[node.first] - w[u]);
if(col[node.first] == 2) dp[1][u] = min(dp[1][u], w[node.first] - w[u]);
}
if(col[u] == 1) res = min(res, dp[1][u]);
if(col[u] == 2) res = min(res, dp[0][u]);
res = min(res, dp[1][u] + dp[0][u]);
}
// Grader entry point, called before any Query. The tree on n nodes has n - 1
// edges: a[i] -- b[i] with length d[i] for i = 0 .. n-2. Node ids are 0-based
// and shifted to the 1-based ids used internally. Builds `adj`, runs the
// Euler-tour DFS from node 1 and builds the LCA sparse table.
// The parameters are declared `signed` on purpose. `#define int long long` is
// in effect, so writing `int` here would define an unrelated
// Init(long long, long long*, ...) overload. The grader links against
// Init(int, int*, int*, int*).
void Init(signed n, signed a[], signed b[], signed d[]) {
    for (int i = 0; i + 1 < n; i++) {
        // Shift to 1-based ids in locals; the caller's arrays are left untouched.
        const int u = a[i] + 1, v = b[i] + 1;
        adj[u].push_back({v, d[i]});
        adj[v].push_back({u, d[i]});
    }
    dfs(1, -1);
    build();
}
// Grader entry point. Returns the minimum tree distance between any node in
// X[0..S-1] and any node in Y[0..T-1] (0-based ids, shifted to 1-based here).
// It builds a virtual tree over the query nodes plus the LCAs of tin-adjacent
// pairs, runs dfs2 on it, and then resets all per-query state (adj2 and col)
// so the next query starts clean.
// Parameters are `signed` because `#define int long long` is in effect.
// Writing `int` would define a Query(long long, long long*, ...) overload
// instead of the Query(int, int*, int, int*) the grader links against.
long long Query(signed S, signed X[], signed T, signed Y[]) {
    res = 1e18;

    // Colour the query nodes (1 = X, 2 = Y). Ids are shifted in locals, so the
    // caller's arrays are not modified.
    vector<int> marked;
    marked.reserve(S + T);
    for (int i = 0; i < S; i++) {
        col[X[i] + 1] = 1;
        marked.push_back(X[i] + 1);
    }
    for (int i = 0; i < T; i++) {
        col[Y[i] + 1] = 2;
        marked.push_back(Y[i] + 1);
    }
    if (marked.empty()) return res;  // nothing to pair; presumably the grader always passes S, T >= 1
    sort(marked.begin(), marked.end(), cmp);

    // Adding the LCA of every pair adjacent in tin order makes the node set
    // closed under LCA, so it forms a valid virtual tree.
    vector<int> vt(marked);
    vt.reserve(2 * marked.size());
    for (size_t i = 0; i + 1 < marked.size(); i++)
        vt.push_back(get(tin[marked[i]], tin[marked[i + 1]]).second);
    sort(vt.begin(), vt.end(), cmp);
    vt.erase(unique(vt.begin(), vt.end()), vt.end());

    // vt[0] has the smallest tin, i.e. it is the ancestor of every other
    // virtual-tree node, so the ancestor chain below never becomes empty.
    stack<int> chain;
    chain.push(vt[0]);
    for (size_t i = 1; i < vt.size(); i++) {
        const int p = vt[i];
        // Pop until the top is an ancestor of p (p's tin inside its [tin, tout]).
        while (tin[chain.top()] > tin[p] || tout[chain.top()] < tin[p]) chain.pop();
        const int parent = chain.top();
        adj2[parent].push_back({p, w[p] - w[parent]});
        chain.push(p);
    }

    dfs2(vt[0], -1);

    // Every X/Y node is in vt, so this clears all colours set above as well as
    // every virtual-tree edge added in this query.
    for (size_t i = 0; i < vt.size(); i++) {
        adj2[vt[i]].clear();
        col[vt[i]] = 0;
    }
    return res;
}
/*
Compilation log captured from an earlier build of this file. It is kept for
reference only and is not part of the program. Its line numbers refer to that
build and are one higher than the matching lines in this file.
The two "undefined reference" errors occur because `#define int long long`
turns the Init/Query definitions into long long overloads, while the grader
links against Init(int, int*, int*, int*) and Query(int, int*, int, int*).

Compilation message
factories.cpp: In function 'void build()':
factories.cpp:27:20: warning: comparison of integer expressions of different signedness: 'long long int' and 'std::vector<long long int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
27 | for (int i = 0; i < euler.size(); i++) up[0][i] = {tin[euler[i]], euler[i]};
| ~~^~~~~~~~~~~~~~
factories.cpp:29:34: warning: comparison of integer expressions of different signedness: 'long long int' and 'std::vector<long long int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
29 | for (int j = 0; j + (1<<i) - 1 < euler.size(); j++) {
| ~~~~~~~~~~~~~~~^~~~~~~~~~~~~~
factories.cpp: In function 'long long int Query(long long int, long long int*, long long int, long long int*)':
factories.cpp:68:20: warning: comparison of integer expressions of different signedness: 'long long int' and 'std::vector<long long int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
68 | for (int i = 0; i < v.size() - 1; i++) {
| ~~^~~~~~~~~~~~~~
factories.cpp:76:20: warning: comparison of integer expressions of different signedness: 'long long int' and 'std::vector<long long int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
76 | for (int i = 1; i < v2.size(); i++) {
| ~~^~~~~~~~~~~
factories.cpp:87:6: warning: 'root' may be used uninitialized in this function [-Wmaybe-uninitialized]
87 | dfs2(root, -1);
| ~~~~^~~~~~~~~~
/usr/bin/ld: /tmp/ccQ2NBuG.o: in function `main':
grader.cpp:(.text.startup+0x37d): undefined reference to `Init(int, int*, int*, int*)'
/usr/bin/ld: grader.cpp:(.text.startup+0x412): undefined reference to `Query(int, int*, int, int*)'
collect2: error: ld returned 1 exit status
*/