#include "factories.h"
#include "bits/stdc++.h"
using namespace std;
// Number of binary-lifting levels kept in jump[][]; lca() iterates over
// exactly this many levels.
const int LOG = 23;
// Capacity of every per-vertex array below.
const int MAX_N = 5e5 + 5;
// Next preorder index to hand out; advanced by dfs().
int cnt;
// Per-query label: Query() writes 1 for vertices listed in X, 2 for vertices
// listed in Y, and resets the entry to 0 before returning.
int tp[MAX_N];
// dep[v]: number of edges between vertex 0 and v (filled by dfs()).
int dep[MAX_N];
// jump[i][v]: the vertex reached from v by following jump[0] (the dfs parent)
// 2^i times; entries for vertex 0 stay at their zero-initialised value.
int jump[LOG][MAX_N];
// rdis[v]: sum of edge weights on the path from vertex 0 to v.
long long rdis[MAX_N];
// important[v]: true while v is one of the vertices of the current query.
bool important[MAX_N];
// dp[v][k] for k in {1, 2}: value computed by dfs2(); index 0 is never
// touched by the code in this file.
long long dp[MAX_N][3];
// in[v]: preorder index assigned on entering v; out[v]: value of cnt - 1 on
// leaving v, i.e. the largest preorder index inside v's subtree.
int in[MAX_N], out[MAX_N];
// Adjacency list of the input tree: (neighbour, edge weight).
vector<pair<int, int>> g[MAX_N];
// Children lists of the per-query virtual tree: (child, path length to the
// child). Query() fills and clears these on every call.
vector<pair<int, long long>> vg[MAX_N];
// Depth-first walk of the input tree starting at v, whose parent is p.
// For every vertex reached it records the preorder interval [in, out]; for
// every non-parent neighbour it fills dep, rdis and the whole jump[][] column
// before descending into it.
void dfs(int v, int p) {
    in[v] = cnt;
    ++cnt;
    for (const auto& edge : g[v]) {
        const int child = edge.first;
        if (child != p) {
            dep[child] = dep[v] + 1;
            rdis[child] = rdis[v] + edge.second;
            // Level 0 is the direct parent; each higher level composes two
            // jumps of the level below it.
            jump[0][child] = v;
            for (int level = 0; level + 1 < LOG; ++level) {
                jump[level + 1][child] = jump[level][jump[level][child]];
            }
            dfs(child, v);
        }
    }
    // Everything numbered since in[v] belongs to v's subtree.
    out[v] = cnt - 1;
}
// Returns the lowest common ancestor of u and v using the jump[][] table.
int lca(int u, int v) {
    if (dep[u] < dep[v]) {
        swap(u, v);
    }
    // Lift the deeper vertex by the depth difference, one set bit at a time.
    const int gap = dep[u] - dep[v];
    for (int bit = 0; bit < LOG; ++bit) {
        if ((gap >> bit) & 1) {
            u = jump[bit][u];
        }
    }
    if (u == v) {
        return v;
    }
    // Lift both vertices while their ancestors still differ; afterwards they
    // sit directly below the common ancestor.
    for (int bit = LOG; bit-- > 0;) {
        const int upU = jump[bit][u];
        const int upV = jump[bit][v];
        if (upU != upV) {
            u = upU;
            v = upV;
        }
    }
    return jump[0][u];
}
// Returns true when the preorder interval of v contains that of u, i.e. when
// v is u itself or an ancestor of u (note the argument order).
bool is_parent(int u, int v) {
    const bool startsOutside = in[u] < in[v];
    const bool endsOutside = out[u] > out[v];
    return !(startsOutside || endsOutside);
}
// Returns the sum of edge weights on the tree path between u and v.
// The value is derived from the 64-bit root distances in rdis[], so it is
// returned as long long; an int return type would truncate long paths.
long long dis(int u, int v) {
    const long long throughRoot = rdis[u] + rdis[v];
    // The segment from the root to the LCA is counted twice above.
    return throughRoot - 2 * rdis[lca(u, v)];
}
// Post-order pass over the virtual tree rooted at v.
// Sets dp[v][k] (k = 1, 2) to the smallest path length from v down to a
// vertex labelled k inside v's virtual subtree (1e15 if there is none), and
// returns the minimum of dp[x][1] + dp[x][2] over all x in that subtree.
long long dfs2(int v) {
    long long* best = dp[v];
    best[1] = 1e15;
    best[2] = 1e15;
    if (important[v]) {
        best[tp[v]] = 0;
    }
    long long answer = LLONG_MAX;
    for (const auto& edge : vg[v]) {
        const int child = edge.first;
        const long long length = edge.second;
        const long long below = dfs2(child);
        if (below < answer) {
            answer = below;
        }
        best[1] = min(best[1], dp[child][1] + length);
        best[2] = min(best[2], dp[child][2] + length);
    }
    return min(answer, best[1] + best[2]);
}
// Builds the adjacency list of the input tree and runs the preprocessing
// walk from vertex 0.
//   N        number of vertices
//   A, B, D  endpoints and weight of each edge
// The rest of this file treats the graph as a tree on N vertices, which has
// N - 1 edges, so only indices 0 .. N - 2 of A, B and D are read.
void Init(int N, int A[], int B[], int D[]) {
    for (int i = 0; i + 1 < N; ++i) {
        g[A[i]].emplace_back(B[i], D[i]);
        g[B[i]].emplace_back(A[i], D[i]);
    }
    dfs(0, 0);
}
long long Query(int S, int X[], int T, int Y[]) {
vector<int> factories;
for (int i = 0; i < S; ++i) {
factories.push_back(X[i]);
important[X[i]] = true;
tp[X[i]] = 1;
}
bool rep = false;
for (int i = 0; i < T; ++i) {
factories.push_back(Y[i]);
important[Y[i]] = true;
if (tp[Y[i]]) {
rep = true;
}
tp[Y[i]] = 2;
}
sort(factories.begin(), factories.end(), [&](int u, int v) {
return in[u] < in[v];
});
factories.push_back(0);
for (int i = 0; i < S + T - 1; ++i) {
factories.push_back(lca(factories[i], factories[i + 1]));
}
sort(factories.begin(), factories.end(), [&](int u, int v) {
return in[u] < in[v];
});
factories.resize(unique(factories.begin(), factories.end()) - factories.begin());
stack<int> stk;
assert(factories[0] == 0);
stk.push(0);
for (int i = 1; i < (int) factories.size(); ++i) {
int cur = factories[i];
while (!is_parent(cur, stk.top())) {
stk.pop();
}
vg[stk.top()].emplace_back(cur, dis(cur, stk.top()));
stk.push(cur);
}
long long ans = 0;
if (!rep) {
ans = dfs2(0);
}
for (int i = 0; i < S; ++i) {
tp[X[i]] = 0;
important[X[i]] = false;
}
for (int i = 0; i < T; ++i) {
tp[Y[i]] = 0;
important[Y[i]] = false;
}
for (int i = 0; i < (int) factories.size(); ++i) {
vg[factories[i]].clear();
}
return ans;
}
| # | 결과 | 실행 시간 | 메모리 | Grader output |
| 1 | Correct | 27 ms | 93016 KB | Output is correct |
| 2 | Correct | 721 ms | 97520 KB | Output is correct |
| 3 | Correct | 818 ms | 97680 KB | Output is correct |
| 4 | Incorrect | 752 ms | 97784 KB | Output isn't correct |
| 5 | Halted | 0 ms | 0 KB | - |
| # | 결과 | 실행 시간 | 메모리 | Grader output |
| 1 | Correct | 14 ms | 93016 KB | Output is correct |
| 2 | Correct | 1616 ms | 135276 KB | Output is correct |
| 3 | Incorrect | 2712 ms | 140684 KB | Output isn't correct |
| 4 | Halted | 0 ms | 0 KB | - |
| # | 결과 | 실행 시간 | 메모리 | Grader output |
| 1 | Correct | 27 ms | 93016 KB | Output is correct |
| 2 | Correct | 721 ms | 97520 KB | Output is correct |
| 3 | Correct | 818 ms | 97680 KB | Output is correct |
| 4 | Incorrect | 752 ms | 97784 KB | Output isn't correct |
| 5 | Halted | 0 ms | 0 KB | - |