#include "deliveries.h"
#include <bits/stdc++.h>
using namespace std;
// Shared fixed-capacity storage. Vertex ids are 0 .. num_cities - 1 and must be < MAXN.
using ll = long long;
constexpr int MAXN = 100005;

// adj[v]: (neighbour, road weight) pairs; init() stores every road at both endpoints.
vector<pair<int, int>> adj[MAXN];
// delivery_counts[v]: the delivery count currently set for city v.
ll delivery_counts[MAXN];
// Filled by dfs_size: parent (-1 for the root), subtree size and depth of each vertex.
// NOTE(review): depth[] is written but never read in this file.
int parent[MAXN], subtree_size[MAXN], depth[MAXN];
// heavy_child (-1 if none) is filled by dfs_size; path_head / pos_in_tree by dfs_hld.
int heavy_child[MAXN], path_head[MAXN], pos_in_tree[MAXN];
// Weight of the road from v up to parent[v]. Never written for the root, so it stays 0 there.
int edge_weight_to_parent[MAXN];
// timer: next free HLD position during dfs_hld; num_cities: N as passed to init().
int timer, num_cities;
// Segment Tree 1: Manages the number of deliveries in subtrees
struct SubtreeSumTree {
ll tree[MAXN << 2], lazy[MAXN << 2];
void push(int node, int l, int r) {
if (lazy[node] != 0) {
tree[node] += (ll)(r - l + 1) * lazy[node];
if (l != r) {
lazy[node << 1] += lazy[node];
lazy[node << 1 | 1] += lazy[node];
}
lazy[node] = 0;
}
}
void update(int node, int l, int r, int ql, int qr, int val) {
push(node, l, r);
if (ql > r || qr < l) return;
if (ql <= l && r <= qr) {
lazy[node] += val;
push(node, l, r);
return;
}
int mid = (l + r) >> 1;
update(node << 1, l, mid, ql, qr, val);
update(node << 1 | 1, mid + 1, r, ql, qr, val);
tree[node] = tree[node << 1] + tree[node << 1 | 1];
}
ll query(int node, int l, int r, int ql, int qr) {
push(node, l, r);
if (ql > r || qr < l) return 0;
if (ql <= l && r <= qr) return tree[node];
int mid = (l + r) >> 1;
return query(node << 1, l, mid, ql, qr) + query(node << 1 | 1, mid + 1, r, ql, qr);
}
} sum_st;
// Segment Tree 2: Manages the weighted contribution (Weight * SubtreeDeliveries)
struct WeightedEdgeTree {
ll tree[MAXN << 2], lazy[MAXN << 2], base_weights[MAXN << 2];
void build(int node, int l, int r, const vector<int>& weights) {
if (l == r) {
base_weights[node] = weights[l];
return;
}
int mid = (l + r) >> 1;
build(node << 1, l, mid, weights);
build(node << 1 | 1, mid + 1, r, weights);
base_weights[node] = base_weights[node << 1] + base_weights[node << 1 | 1];
}
void push(int node, int l, int r) {
if (lazy[node] != 0) {
tree[node] += base_weights[node] * lazy[node];
if (l != r) {
lazy[node << 1] += lazy[node];
lazy[node << 1 | 1] += lazy[node];
}
lazy[node] = 0;
}
}
void update(int node, int l, int r, int ql, int qr, int val) {
push(node, l, r);
if (ql > r || qr < l) return;
if (ql <= l && r <= qr) {
lazy[node] += val;
push(node, l, r);
return;
}
int mid = (l + r) >> 1;
update(node << 1, l, mid, ql, qr, val);
update(node << 1 | 1, mid + 1, r, ql, qr, val);
tree[node] = tree[node << 1] + tree[node << 1 | 1];
}
ll query_all() { return tree[1]; }
} weight_st;
// --- HLD Implementation ---
// First HLD pass (init() calls it as dfs_size(0, -1, 0)).
// For every vertex in u's subtree it records the parent p, the depth d, the subtree size, and
// the weight of the road up to the parent. It also records the heavy child: the child with
// the largest subtree, where the earliest such child in adjacency order wins ties, or -1 if
// u has no children.
void dfs_size(int u, int p, int d) {
    parent[u] = p;
    depth[u] = d;
    heavy_child[u] = -1;
    subtree_size[u] = 1;
    for (const auto& edge : adj[u]) {
        const int child = edge.first;
        if (child == p) continue;  // do not walk back up to the parent
        edge_weight_to_parent[child] = edge.second;
        dfs_size(child, u, d + 1);
        subtree_size[u] += subtree_size[child];
        const int best = heavy_child[u];
        const bool strictly_bigger = (best == -1) || subtree_size[best] < subtree_size[child];
        if (strictly_bigger) heavy_child[u] = child;
    }
}
// Second HLD pass: gives u the next preorder position and records h as the head of u's
// heavy chain. Visiting the heavy child first makes every heavy chain a run of consecutive
// positions. Because this is a preorder walk, v's subtree occupies positions
// [pos_in_tree[v], pos_in_tree[v] + subtree_size[v] - 1].
void dfs_hld(int u, int h) {
    path_head[u] = h;
    pos_in_tree[u] = timer++;
    const int heavy = heavy_child[u];
    // The heavy child continues u's chain at the very next position.
    if (heavy != -1) dfs_hld(heavy, h);
    // Every remaining child starts a new chain headed by itself.
    for (const auto& edge : adj[u]) {
        const int child = edge.first;
        if (child == parent[u] || child == heavy) continue;
        dfs_hld(child, child);
    }
}
// Adds delta to S(v) for every vertex v on the path from u up to the root (inclusive), in
// both segment trees. It issues one range update per heavy chain crossed: the chain segment
// [head, cur] is contiguous in HLD order.
void update_path(int u, int delta) {
    for (int cur = u; cur != -1;) {
        const int head = path_head[cur];
        const int lo = pos_in_tree[head];
        const int hi = pos_in_tree[cur];
        sum_st.update(1, 0, num_cities - 1, lo, hi, delta);
        weight_st.update(1, 0, num_cities - 1, lo, hi, delta);
        cur = parent[head];  // jump to the chain above (-1 once we pass the root)
    }
}
// Correction term added by max_time() to 2 * sum_v w(v) * S(v).
// Here S(v) is v's subtree delivery total, w(v) is the weight of the road from v to its
// parent, and `total` is S(root). For every child edge whose lower side holds more than half
// of the deliveries, the term replaces 2 * w * S with 2 * w * (total + 1 - S). The result is
// therefore <= 0, and base + correction equals 2 * sum_v w(v) * min(S(v), total + 1 - S(v)).
// (The +1 presumably models the route's final return to city 0; confirm against the spec.)
//
// Assuming counts are non-negative, at most one child of any vertex can hold more than half
// of `total`, so the qualifying edges form one downward chain. This walks it from u
// (max_time passes the root, 0). Testing 2*S > total rather than 2*S > total + 1 is harmless:
// when 2*S == total + 1 the added term is exactly zero.
//
// NOTE(review): each step scans every child of the current vertex with a segment-tree point
// query, so a call costs O(sum of degrees along the chain * log N). That is O(N log N) on a
// star or a long path. A faster version would find the deepest qualifying vertex with a
// max-augmented tree over HLD positions and sum the chain with path queries.
ll find_unbalanced_correction(int u) {
    // Position 0 is the root's slot (init runs dfs_hld(0, 0) with timer = 0), so this is the
    // overall delivery total. It cannot change during the walk, so it is queried only once.
    const ll total = sum_st.query(1, 0, num_cities - 1, 0, 0);
    ll correction = 0;
    int cur = u;
    // Iterative walk, so the chain's length does not become recursion depth.
    while (true) {
        int deeper = -1;
        for (const auto& edge : adj[cur]) {
            const int v = edge.first;
            if (v == parent[cur]) continue;
            const ll sub_sum = sum_st.query(1, 0, num_cities - 1, pos_in_tree[v], pos_in_tree[v]);
            if (sub_sum * 2 > total) {
                // 2 * w * ((total + 1 - S) - S): swaps this edge's term from S to total + 1 - S.
                correction += 2LL * edge.second * (total - 2 * sub_sum + 1);
                deeper = v;
                break;  // no sibling can also exceed half
            }
        }
        if (deeper == -1) return correction;
        cur = deeper;
    }
}
void init(int N, vector<int> U, vector<int> V, vector<int> W, vector<int> T) {
num_cities = N;
timer = 0;
for (int i = 0; i < N - 1; i++) {
adj[U[i]].push_back({V[i], W[i]});
adj[V[i]].push_back({U[i], W[i]});
}
dfs_size(0, -1, 0);
dfs_hld(0, 0);
vector<int> mapped_weights(N);
for (int i = 0; i < N; i++) {
delivery_counts[i] = T[i];
mapped_weights[pos_in_tree[i]] = edge_weight_to_parent[i];
}
weight_st.build(1, 0, N - 1, mapped_weights);
for (int i = 0; i < N; i++) {
if (T[i] > 0) update_path(i, T[i]);
}
}
// Sets city_index's delivery count to new_delivery_count and returns
// 2 * sum_v w(v) * S(v) + find_unbalanced_correction(0), evaluated after the change.
ll max_time(int city_index, int new_delivery_count) {
    const ll previous = delivery_counts[city_index];
    delivery_counts[city_index] = new_delivery_count;
    // Every subtree containing the city shifts by the same difference.
    update_path(city_index, static_cast<int>(new_delivery_count - previous));
    const ll base = 2LL * weight_st.query_all();
    return base + find_unbalanced_correction(0);
}