| # | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리 |
|---|---|---|---|---|---|---|---|
| 1304518 |  | disfyy | 경주 (Race) (IOI11_race) | C++20 |  | 0 ms | 0 KiB |
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
using ld = long double;

// Upper bound for every per-vertex table below (vertices are used 1-based).
constexpr ll N = 500'005;
constexpr ll NN = 5'005;
// "No answer" sentinel. Numerically this is (ll)(2e18 + 5) == 2e18: the +5
// is absorbed by double rounding, so the written value is kept identical.
constexpr ll INF = 2'000'000'000'000'000'000LL;
constexpr ll MOD = 1'000'000'007;
constexpr ll P = 53;
constexpr double eps = 1e-9;

ll a[N];              // not referenced by the visible code
ll up[N][23];         // up[v][j]: 2^j-th ancestor of v (set by dfs)
ll sum[N][23];        // sum[v][j]: weight of the 2^j edges jumped by up[v][j]
ll dp[N];             // depth of each vertex, root = 1
ll tin[N], tout[N];   // DFS entry time / last entry time inside the subtree
ll timer = 0;         // running DFS clock
vector<pair<ll, ll>> g[N];  // adjacency list: (neighbour, edge weight)
// Recursive DFS from v (whose parent is p) that fills, for every vertex reached:
//   dp[v]     - depth; dp[p] + 1, so the root (called with p = 0, dp[0] == 0) gets 1,
//   tin[v]    - entry time from the global clock `timer`,
//   tout[v]   - value of `timer` after the whole subtree was entered,
//   up[v][j]  - 2^j-th ancestor, built from the already-filled ancestor rows,
//   sum[v][j] - weight of the edges crossed by that 2^j jump.
// Precondition: up[v][0] / sum[v][0] were set by the caller (the parent loop
// below, or solve() for the root, whose up[1][0] is itself).
void dfs(ll v, ll p = 0) {
    dp[v] = 1 + dp[p];
    tin[v] = ++timer;

    for(int j = 1; j < 23; ++j) {
        // Jumping 2^j = two jumps of 2^(j-1) through the halfway ancestor.
        const ll half = up[v][j - 1];
        up[v][j] = up[half][j - 1];
        sum[v][j] = sum[v][j - 1] + sum[half][j - 1];
    }

    for(const auto& [child, weight] : g[v]) {
        if(child == p) {
            continue;  // do not walk back up the tree
        }
        up[child][0] = v;
        sum[child][0] = weight;
        dfs(child, v);
    }

    tout[v] = timer;
}
// True iff u is an ancestor of v in the DFS tree (every vertex counts as its
// own ancestor): v's entry time lies inside u's [tin, tout] interval.
bool check(ll u, ll v) {
    return tin[u] <= tin[v] && tout[v] <= tout[u];
}
// Lowest common ancestor of u and v using the binary-lifting table `up`.
// If neither vertex is an ancestor of the other, u is lifted as far as
// possible while it stays a non-ancestor of v; its parent is then the LCA.
ll lca(ll u, ll v) {
    if(check(u, v)) return u;
    if(check(v, u)) return v;

    for(int j = 22; j >= 0; --j) {
        const ll candidate = up[u][j];
        if(!check(candidate, v)) {
            u = candidate;
        }
    }
    return up[u][0];
}
// Total edge weight collected while moving `need` steps up from u, by
// decomposing `need` into powers of two (only bits 0..22 are examined).
// Callers pass a depth difference, so presumably 0 <= need < depth(u).
ll calc(ll u, ll need) {
    ll total = 0;
    for(int bit = 22; bit >= 0; --bit) {
        if(((need >> bit) & 1LL) == 0) {
            continue;
        }
        total += sum[u][bit];
        u = up[u][bit];
    }
    return total;
}
void solve() {
ll n, k;
cin >> n >> k;
vector<pair<pair<ll, ll>, ll>> d(n);
for(int i = 1; i <= n - 1; i++) {
ll u, v;
cin >> u >> v;
u++, v++;
d[i - 1].first.first = u, d[i - 1].first.second = v;
}
for(int i = 1; i <= n - 1; i++) {
ll x;
cin >> x;
d[i - 1].second = x;
}
for(auto x : d) {
g[x.first.first].push_back({x.first.second, x.second});
g[x.first.second].push_back({x.first.first, x.second});
}
up[1][0] = 1;
dfs(1);
ll mn = INF;
for(int i = 1; i <= n; i++) {
for(int j = 1; j <= n; j++) {
ll res = lca(i, j);
ll f = calc(i, dp[i] - dp[res]) + calc(j, dp[j] - dp[res]);
// cout << i << " " << j << " " << f << '\n';
if(f == k) {
mn = min(mn, dp[i] - dp[res] + dp[j] - dp[res]);
}
}
}
if(mn == INF) {
mn = -1;
}
cout << mn << '\n';
}
// Entry point: fast unsynchronised iostreams, then a single solve() call
// (reading a test count from input is disabled).
signed main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);

    constexpr int kTests = 1;
    for(int t = 0; t < kTests; ++t) {
        solve();
    }
    return 0;
}
