이 제출은 이전 버전의 oj.uz에서 채점하였습니다. 현재는 제출 당시와는 다른 서버에서 채점을 하기 때문에, 다시 제출하면 결과가 달라질 수도 있습니다.
#include <bits/stdc++.h>
#include "factories.h"
//~ #include "grader.cpp"
using namespace std;
// Shorthand for long long, used for weighted distances.
#define ll long long
// Capacity of every vertex-indexed array (vertex ids 0 .. NN-1).
const int NN = 500001;
// Number of binary-lifting levels; lift() can therefore jump at most 2^K - 1 edges.
const int K = 20;
// Sentinel distance meaning "no such vertex"; INF + INF still fits in a long long.
const ll INF = (ll)1e18;
// mark[x]   : per-query tag (0 = untagged, 1 = listed in X, 2 = listed in Y), reset after each query.
// s[x], e[x]: preorder entry time of x and the last preorder time inside x's subtree.
// dep[x]    : depth of x in edges from the root (vertex 0, as passed by Init).
// p[x][i]   : the 2^i-th ancestor of x (the root is its own parent).
// dfstime   : running preorder counter used by dfs().
int mark[NN], s[NN], e[NN], dep[NN], p[NN][K], dfstime;
// d[x]: weighted distance from the root to x; ans: best distance found by the current query.
ll d[NN], ans;
// v: input-tree adjacency as (neighbour, edge weight).
// g: adjacency of the current query's virtual tree (same format), cleared at the end of each query.
vector <pair <int, ll> > v[NN], g[NN];
void dfs(int node, int pnode){
s[node] = dfstime++;
p[node][0] = pnode;
for(int i = 1 ; i < K ; i++){
p[node][i] = p[p[node][i - 1]][i - 1];
}
for(auto &i : v[node]){
if(i.first == pnode) continue;
d[i.first] = d[node] + i.second;
dep[i.first] = dep[node] + 1;
dfs(i.first, node);
}
e[node] = dfstime - 1;
}
// Ancestor test on preorder intervals: true when y belongs to the subtree of x
// (including y == x), i.e. [s[y], e[y]] is nested inside [s[x], e[x]].
bool inside(int x, int y){
    const bool enters_later = s[y] >= s[x];
    const bool leaves_earlier = e[x] >= e[y];
    return enters_later && leaves_earlier;
}
// Moves x up k edges by following p[][] once per set bit of k (lowest bit
// first). Only the low K bits of k are honoured; jumps past the root stay at
// the root because the root is its own ancestor in p[][].
int lift(int x, int k){
    for(int bit = 0 ; bit < K && k != 0 ; ++bit, k >>= 1){
        if(k & 1) x = p[x][bit];
    }
    return x;
}
// Lowest common ancestor of x and y via binary lifting over p[][].
// After equalising depths, jumps must be tried from the LARGEST power of two
// down to the smallest. In ascending order the walk can stop short. Example:
// two vertices 3 edges below their LCA jump 1 edge, then refuse the 2-edge
// jump that lands exactly on the LCA, and finally return a vertex one level
// below the real LCA.
int LCA(int x, int y){
    if(dep[x] > dep[y]) swap(x, y);   // make x the shallower vertex
    y = lift(y, dep[y] - dep[x]);     // bring y up to x's depth
    if(x == y) return x;              // x was an ancestor of y
    for(int i = K - 1 ; i >= 0 ; i--){
        // Jump only while the 2^i-th ancestors differ, i.e. stay strictly below the LCA.
        if(p[x][i] != p[y][i]){
            x = p[x][i];
            y = p[y][i];
        }
    }
    // x and y are now distinct children of the LCA.
    return p[x][0];
}
pair <ll, ll> solve(int node, int pnode){
pair <ll, ll> cur = make_pair(INF, INF);
if(mark[node] == 1) cur.first = 0;
if(mark[node] == 2) cur.second = 0;
for(auto &i : g[node]){
if(i.first == pnode) continue;
auto f = solve(i.first, node);
f.first += i.second;
f.second += i.second;
ans = min(ans, cur.first + f.second);
ans = min(ans, cur.second + f.first);
cur.first = min(cur.first, f.first);
cur.second = min(cur.second, f.second);
}
return cur;
}
// Stores the N-1 weighted edges (A[i], B[i], D[i]) as undirected adjacency
// lists in v[], then preprocesses the tree rooted at vertex 0 with dfs().
void Init(int N, int A[], int B[], int D[]) {
    for(int edge = 0 ; edge + 1 < N ; ++edge){
        const int a = A[edge];
        const int b = B[edge];
        const int w = D[edge];
        v[a].emplace_back(b, w);
        v[b].emplace_back(a, w);
    }
    dfs(0, 0);
}
long long Query(int S, int X[], int T, int Y[]) {
vector <int> all;
for(int i = 0 ; i < S ; i++){
mark[X[i]] = 1;
all.push_back(X[i]);
}
for(int i = 0 ; i < T ; i++){
mark[Y[i]] = 2;
all.push_back(Y[i]);
}
sort(all.begin(), all.end(), [&](int l, int r){
return s[l] < s[r];
});
int all_sz = all.size();
for(int i = 0 ; i < all_sz ; i++){
all.push_back(LCA(all[i], all[(i + 1) % all.size()]));
}
sort(all.begin(), all.end());
all.erase(unique(all.begin(), all.end()), all.end());
sort(all.begin(), all.end(), [&](int l, int r){
return s[l] < s[r];
});
vector <int> cur;
for(auto &i : all){
while(cur.size() && !inside(cur.back(), i)) cur.pop_back();
if(cur.size()){
ll cost = d[i] - d[cur.back()];
g[cur.back()].push_back(make_pair(i, cost));
g[i].push_back(make_pair(cur.back(), cost));
}
cur.push_back(i);
}
ans = INF;
solve(all[0], 0);
for(auto &i : all){
g[i].clear();
mark[i] = 0;
}
assert(ans < INF);
return ans;
}
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |