# | Time | Username | Problem | Language | Result | Execution time | Memory |
---|---|---|---|---|---|---|---|
1255922 | kamrad | Factories (JOI14_factories) | C++20 | 0 ms | 0 KiB |
#include <bits/stdc++.h>
#include "factories.h"
using namespace std;
//#pragma GCC optimize("Ofast,unroll-loops")
//#pragma GCC target("avx2,popcnt,lzcnt,abm,bmi,bmi2,fma,tune=native")
// ---- Type aliases -----------------------------------------------------------
using ll = long long;
using ld = long double;
// NOTE: despite the name, pii holds two `ll` values (same type as pll).
using pii = std::pair<ll, ll>;
using pll = std::pair<ll, ll>;
using pi3 = std::pair<pii, ll>;

// ---- Shorthand macros -------------------------------------------------------
// Names are std::-qualified so the macros do not depend on a using-directive.
#define IOS std::ios_base::sync_with_stdio(false); std::cin.tie(0); std::cout.tie(0);
#define F first
#define S second
#define sz(x) x.size()
#define all(x) x.begin(), x.end()
#define pb push_back
// minr/maxr assign the min/max of (a, b) back into a; both arguments must have
// the same type. The expansion already ends in ';'.
#define minr(a, b) a = std::min(a, b);
#define maxr(a, b) a = std::max(a, b);
// Debug helpers: print a marker line / spin forever.
#define shit std::cout << "shit\n" << std::flush;
#define tl while(1&1) continue;
// rand(l, r): uniformly distributed integer in the closed range [l, r], drawn
// from `rng`. The previous expansion named `uniform_ll_distribution<ll64_t>`,
// which does not exist, so any use of the macro failed to compile.
#define rand(l, r) std::uniform_int_distribution<ll>(l, r)(rng)
std::random_device device; std::default_random_engine rng(device());

// ---- Constants --------------------------------------------------------------
const ll Mod = 1e9 + 7; //998244353;
const ll LG = 30;        // number of levels in the binary-lifting table
const ll SQ = 500;
const ll Inf = 2e18 + 10; // "no value" sentinel for minima
const ll maxN = 5e5 + 10; // capacity of the per-vertex arrays
// Number of vertices; assigned from N in Init.
ll n;
// Next timestamp handed out by dfs; the first vertex visited receives 1.
ll timer = 1;
// h[v]: sum of edge weights from the DFS root to v (dfs sets h[v] = h[u] + w).
ll h[maxN];
// d[v]: number of edges from the DFS root to v (dfs sets d[v] = d[u] + 1).
ll d[maxN];
// st[v]: timestamp assigned to v when dfs enters it.
ll st[maxN];
// ft[v]: value of timer when dfs(v) finishes, so the half-open range
// [st[v], ft[v]) holds exactly the timestamps of v's subtree.
ll ft[maxN];
// par[v][0] is v's DFS parent and par[v][i] = par[par[v][i-1]][i-1], i.e. the
// ancestor 2^i edges above v. Entries stay 0 (zero-initialised global) when no
// such ancestor exists; Init shifts vertex ids to start at 1, so 0 is never a
// real vertex.
ll par[maxN][LG];
// Adjacency lists: neighb[u] holds (neighbour, edge weight) pairs.
vector <pii> neighb[maxN];
struct SegTree {
struct Node {
ll mn;
Node() {
mn = Inf;
}
} t[maxN<<2];
void update(ll id, ll L, ll R, ll idx, ll val) {
if(L+1 == R) {
t[id].mn = val;
return;
}
ll mid = (L+R)>>1;
if(idx < mid)
update(2*id+0, L, mid, idx, val);
else
update(2*id+1, mid, R, idx, val);
t[id].mn = min(t[2*id+0].mn, t[2*id+1].mn);
}
ll get(ll id, ll L, ll R, ll l, ll r) {
if(L == l and R == r)
return t[id].mn;
ll ret = Inf;
ll mid = (L+R)>>1;
if(l < mid)
minr(ret, get(2*id+0, L, mid, l, min(mid, r)));
if(r > mid)
minr(ret, get(2*id+1, mid, R, max(l, mid), r));
return ret;
}
} s, t;
// Recursive depth-first traversal starting at u.
// Expects par[u][0], d[u] and h[u] to be set already (all zero for the start
// vertex). For every vertex reached it records the entry timestamp st[], the
// exit bound ft[], the ancestor table par[][], the edge-count depth d[] and the
// weighted distance h[].
void dfs(ll u) {
    st[u] = timer;
    ++timer;

    // Level k of the ancestor table is two jumps of level k-1.
    for (ll k = 1; k < LG; ++k) {
        const ll halfway = par[u][k - 1];
        par[u][k] = par[halfway][k - 1];
    }

    for (const auto &edge : neighb[u]) {
        const ll child = edge.first;
        if (child == par[u][0])
            continue;  // do not walk back along the edge we arrived by
        d[child] = d[u] + 1;
        h[child] = h[u] + edge.second;
        par[child][0] = u;
        dfs(child);
    }

    // timer now sits one past the last timestamp used inside u's subtree.
    ft[u] = timer;
}
// Returns the lowest common ancestor of u and v using the binary-lifting table
// par[][] and the edge-count depth d[] filled in by dfs.
//
// The depths are equalised by lifting the deeper vertex by exactly
// d[v] - d[u] edges. The earlier version compared the weighted distance h[]
// instead, which (a) does not bring both vertices to the same depth, so the
// lock-step phase below could overshoot, and (b) could step onto the sentinel
// vertex 0 because h[0] == h[root] == 0. Both cases could return 0.
ll LCA(ll u, ll v) {
    // Make v the deeper (or equally deep) endpoint.
    if (d[u] > d[v])
        std::swap(u, v);

    // Lift v by the depth difference, one set bit at a time. The difference
    // never exceeds d[v], so v cannot leave the tree.
    const ll diff = d[v] - d[u];
    for (ll i = 0; i < LG; i++)
        if ((diff >> i) & 1)
            v = par[v][i];

    if (u == v)
        return v;  // u was an ancestor of v

    // Both vertices are now at the same depth: jump together while their
    // ancestors differ, ending directly below the LCA.
    for (ll i = LG - 1; i >= 0; i--) {
        if (par[u][i] != par[v][i]) {
            u = par[u][i];
            v = par[v][i];
        }
    }
    return par[u][0];
}
void Init(ll N, ll A[], ll B[], ll D[]) {
n = N;
for(ll i = 0; i < n-1; i++) {
neighb[A[i]+1].pb({B[i]+1, D[i]});
neighb[B[i]+1].pb({A[i]+1, D[i]});
}
dfs(1);
}
// Strict ordering of vertices by dfs entry timestamp: true when u was entered
// before v.
bool cmp(ll u, ll v) {
    const ll enteredU = st[u];
    const ll enteredV = st[v];
    return enteredV > enteredU;
}
long long Query(ll S, ll X[], ll T, ll Y[]) {
set <ll> virt;
vector <ll> val;
for(ll i = 0; i < S; i++) {
val.pb(X[i]+1);
s.update(1, 1, n+1, st[X[i]+1], d[X[i]+1]);
}
for(ll i = 0; i < T; i++) {
val.pb(Y[i]+1);
t.update(1, 1, n+1, st[Y[i]+1], d[Y[i]+1]);
}
sort(all(val), cmp);
for(ll i = 0; i < sz(val)-1; i++)
virt.insert(LCA(val[i], val[i+1]));
ll ans = Inf;
for(auto x : virt) {
ll tmp = s.get(1, 1, n+1, st[x], ft[x]);
tmp += t.get(1, 1, n+1, st[x], ft[x]);
tmp -= 2*d[x];
minr(ans, tmp);
}
for(ll i = 0; i < S; i++)
s.update(1, 1, n+1, st[X[i]+1], Inf);
for(ll i = 0; i < T; i++)
t.update(1, 1, n+1, st[Y[i]+1], Inf);
return ans;
}