# | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리 |
---|---|---|---|---|---|---|---|
170840 | ZwariowanyMarcin | 공장들 (JOI14_factories) | C++14 | 0 ms | 0 KiB |
이 제출은 이전 버전의 oj.uz에서 채점하였습니다. 현재는 제출 당시와는 다른 서버에서 채점을 하기 때문에, 다시 제출하면 결과가 달라질 수도 있습니다.
#include "factories.h"
#include <bits/stdc++.h>
#define fi first
#define se second
#define mp make_pair
#define ss(x) (int) x.size()
#define pb push_back
#define ll long long
#define cat(x) cerr << #x << " = " << x << endl
#define FOR(i, n) for(int i = 0; i < n; ++i)
using namespace std;
const int nax = 5e5 + 111;
int n;
vector <pair<int,int>> v[nax];
int par[nax];
int kil[nax];
int nn;
int siz[nax];
int h[nax];
int jump[nax][20];
ll H[nax];
void prep(int u, int p) {
jump[u][0] = p;
h[u] = h[p] + 1;
for(auto it : v[u])
if(it.fi != p) {
H[it.fi] = H[u] + it.se;
prep(it.fi, u);
}
}
int lca(int x, int y) {
if(h[x] < h[y])
swap(x, y);
for(int i = 19; 0 <= i; --i)
if(h[y] <= h[x] - (1 << i))
x = jump[x][i];
if(x == y)
return x;
for(int i = 19; 0 <= i; --i)
if(jump[x][i] != jump[y][i]) {
x = jump[x][i];
y = jump[y][i];
}
return jump[x][0];
}
ll dis(int x, int y) {
return H[x] + H[y] - 2 * H[lca(x, y)];
}
void dfs(int u, int p) {
nn++;
siz[u] = 1;
for(auto it : v[u])
if(it.fi != p && !kil[it.fi]) {
dfs(it.fi, u);
siz[u] += siz[it.fi];
}
}
int daj(int u, int p) {
for(auto it : v[u])
if(it.fi != p && !kil[it.fi] && nn <= 2 * siz[it.fi])
return daj(it.fi, u);
return u;
}
void decomp(int u, int p) {
nn = 0;
dfs(u, 0);
int C = daj(u, 0);
par[C] = p;
kil[C] = 1;
for(auto it : v[C])
if(!kil[it.fi])
decomp(it.fi, C);
}
ll naj[nax];
void Init(int N, vector <int> a, vector <int> b, vector <int> c) {
n = N;
for(auto &it : a)
it++;
for(auto &it : b)
it++;
for(int i = 0; i < n - 1; ++i) {
v[a[i]].pb(mp(b[i], c[i]));
v[b[i]].pb(mp(a[i], c[i]));
}
prep(1, 0);
decomp(1, 0);
for(int j = 1; j <= 19; ++j)
for(int i = 1; i <= n; ++i)
jump[i][j] = jump[jump[i][j - 1]][j - 1];
for(int i = 1; i <= n; ++i)
naj[i] = 1e18;
}
ll Query(int s, vector <int> x, int t, vector <int> y) {
for(auto &it : x)
it++;
for(auto &it : y)
it++;
ll res = 1e18;
for(auto it : x) {
int node = it;
while(node != 0) {
naj[node] = min(naj[node], dis(it, node));
node = par[node];
}
}
for(auto it : y) {
int node = it;
while(node != 0) {
res = min(res, naj[node] + dis(node, it));
node = par[node];
}
}
for(auto it : x) {
int node = it;
while(node != 0) {
naj[node] = 1e18;
node = par[node];
}
}
return res;
}