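// Note from oj.uz: this submission was judged on an earlier version of oj.uz.
// Grading now runs on a different server than at the time of submission, so
// resubmitting may produce a different result.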
#include <bits/stdc++.h>
#define fi first
#define se second
#define mp make_pair
#define all(x) x.begin(), x.end()
using namespace std;
// Lowers `a` to `b` when `b` is strictly smaller than the current value.
// Returns true if `a` was overwritten, false if it was left untouched.
template <class X, class Y>
bool minimize(X &a, Y b) {
    if (!(a > b)) {
        return false;
    }
    a = b;
    return true;
}
// Raises `a` to `b` when `b` is strictly larger than the current value.
// Returns true if `a` was overwritten, false if it was left untouched.
template <class X, class Y>
bool maximize(X &a, Y b) {
    if (!(a < b)) {
        return false;
    }
    a = b;
    return true;
}
// Capacity of every per-node / per-edge table below.
constexpr int N = 100007;
// Highest level stored in the lifting tables; each row has LOG + 1 entries.
constexpr int LOG = 20;

// Header values read from the first input line by main().
int n;         // the edge-reading loop runs n - 1 times
int num_shop;  // how many shop node ids follow the edge list
int q;         // how many queries follow
int ec;        // node the traversal in main() is rooted at

// Preorder entry stamp of each node, and the largest entry stamp inside its
// subtree; together they give the ancestor test used by main().
int tin[N];
int tout[N];

// shop[u] is true for every node id listed in the shop section of the input.
bool shop[N];

// Number of edges between a node and the root.
int depth[N];

// Per-node value: first the smallest dist[] among flagged nodes in the node's
// subtree, later shifted by -2 * dist[node] (see main()).
long long dp[N];

// Sum of edge weights between a node and the root.
long long dist[N];

// par[u][k]: ancestor of u that is 2^k edges up (the root is its own parent).
int par[N][LOG + 1];

// jump[u][k]: minimum of dp[] over the 2^k nodes starting at u and walking up.
long long jump[N][LOG + 1];

// e[i]: the two endpoints of the i-th input edge (1-based).
std::pair<int, int> e[N];

// adj[u]: list of (neighbour, edge weight) pairs.
std::vector<std::pair<int, int>> adj[N];
signed main() {
cin.tie(0)->sync_with_stdio(0);
cin >> n >> num_shop >> q >> ec;
for (int i = 1; i < n; i ++) {
int u, v, w; cin >> u >> v >> w;
adj[u].emplace_back(v, w);
adj[v].emplace_back(u, w);
e[i] = mp(u, v);
}
for (int i = 1; i <= num_shop; i ++) {
int u; cin >> u;
shop[u] = true;
}
int timer = 0;
const long long inf = 1e18 + 7;
function<void(int, int)> dfs_dist = [&](int u, int p) {
if (!shop[u]) {
dp[u] = inf;
} else {
dp[u] = dist[u];
}
par[u][0] = p;
for (int i = 1; i <= LOG; i ++) {
par[u][i] = par[par[u][i - 1]][i - 1];
}
tin[u] = ++ timer;
for (pair<int, int> v: adj[u]) if (v.fi != p) {
dist[v.fi] = dist[u] + v.se;
depth[v.fi] = depth[u] + 1;
dfs_dist(v.fi, u);
minimize(dp[u], dp[v.fi]);
}
tout[u] = timer;
};
function<void(int, int)> dfs = [&](int u, int p) {
if (dp[u] != inf) {
dp[u] -= 2 * dist[u];
}
jump[u][0] = dp[u];
for (int i = 1; i <= LOG; i ++) {
jump[u][i] = min(jump[u][i - 1], jump[par[u][i - 1]][i - 1]);
}
for (pair<int, int> v: adj[u]) if (v.fi != p) {
dfs(v.fi, u);
}
};
auto anc = [&](int u, int v) -> bool {
return tin[u] <= tin[v] && tout[u] >= tout[v];
};
dfs_dist(ec, ec);
dfs(ec, ec);
while (q --) {
int i, u; cin >> i >> u;
if (dist[e[i].fi] < dist[e[i].se]) {
swap(e[i].fi, e[i].se);
}
int p = e[i].fi;
if (!anc(p, u)) {
cout << "escaped\n";
} else {
long long res = inf;
int d = depth[u] - depth[p] + 1;
int v = u;
for (int j = LOG; j >= 0; j --) {
if (d >> j & 1) {
minimize(res, jump[u][j]);
u = par[u][j];
}
}
if (res >= (long long) 1e16) {
cout << "oo\n";
} else {
cout << res + dist[v] << '\n';
}
}
}
return 0;
}
/*
Judge result tables captured with the page (results were still loading):

# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
*/