#include "deliveries.h"
#include <bits/stdc++.h>
using namespace std;
using ll = long long;

// Capacity of every per-city array (maximum number of cities + slack).
const int MAXN = 100005;

// adj[u] holds (neighbour, edge weight) pairs for city u.
std::vector<std::pair<int, int>> adj[MAXN];
// Current number of deliveries requested in each city.
ll delivery_counts[MAXN];
// Rooted-tree data produced by the first DFS (the tree is rooted at city 0).
int parent[MAXN], subtree_size[MAXN];
// Weighted distance from the root. Kept 64-bit because it is accumulated as a
// sum of edge weights along a root path and can exceed INT_MAX; storing it in
// an int silently truncated it.
ll depth_from_root[MAXN];
// HLD data: heavy child (-1 for a leaf), head of the node's heavy chain,
// position in HLD order, and the inverse position -> node map.
int heavy_child[MAXN], path_head[MAXN], pos_in_tree[MAXN], rev_pos[MAXN];
// Weight of the edge from each node to its parent (never assigned, so 0, for the root).
int edge_weight_to_parent[MAXN];
// `timer` is the next free HLD position; `num_cities` is N.
int timer, num_cities;
// Stays true while every edge i joins cities i and i+1 (the path 0-1-...-N-1).
bool is_line_graph = true;
// Segment Tree 1: Subtree delivery sums
struct SubtreeSumTree {
ll tree[MAXN << 2], lazy[MAXN << 2];
void push(int node, int l, int r) {
if (lazy[node] != 0) {
tree[node] += (ll)(r - l + 1) * lazy[node];
if (l != r) {
lazy[node << 1] += lazy[node];
lazy[node << 1 | 1] += lazy[node];
}
lazy[node] = 0;
}
}
void update(int node, int l, int r, int ql, int qr, int val) {
push(node, l, r);
if (ql > r || qr < l) return;
if (ql <= l && r <= qr) {
lazy[node] += val;
push(node, l, r);
return;
}
int mid = (l + r) >> 1;
update(node << 1, l, mid, ql, qr, val);
update(node << 1 | 1, mid + 1, r, ql, qr, val);
tree[node] = tree[node << 1] + tree[node << 1 | 1];
}
ll query(int node, int l, int r, int ql, int qr) {
push(node, l, r);
if (ql > r || qr < l) return 0;
if (ql <= l && r <= qr) return tree[node];
int mid = (l + r) >> 1;
return query(node << 1, l, mid, ql, qr) + query(node << 1 | 1, mid + 1, r, ql, qr);
}
} sum_st;
// Segment Tree 2: Weighted sums (Weight * SubtreeDeliveries)
struct WeightedEdgeTree {
ll tree[MAXN << 2], lazy[MAXN << 2], base_weights[MAXN << 2];
void build(int node, int l, int r, const vector<int>& weights) {
if (l == r) {
base_weights[node] = weights[l];
return;
}
int mid = (l + r) >> 1;
build(node << 1, l, mid, weights);
build(node << 1 | 1, mid + 1, r, weights);
base_weights[node] = base_weights[node << 1] + base_weights[node << 1 | 1];
}
void push(int node, int l, int r) {
if (lazy[node] != 0) {
tree[node] += base_weights[node] * lazy[node];
if (l != r) {
lazy[node << 1] += lazy[node];
lazy[node << 1 | 1] += lazy[node];
}
lazy[node] = 0;
}
}
void update(int node, int l, int r, int ql, int qr, int val) {
push(node, l, r);
if (ql > r || qr < l) return;
if (ql <= l && r <= qr) {
lazy[node] += val;
push(node, l, r);
return;
}
int mid = (l + r) >> 1;
update(node << 1, l, mid, ql, qr, val);
update(node << 1 | 1, mid + 1, r, ql, qr, val);
tree[node] = tree[node << 1] + tree[node << 1 | 1];
}
ll query(int node, int l, int r, int ql, int qr) {
push(node, l, r);
if (ql > r || qr < l) return 0;
if (ql <= l && r <= qr) return tree[node];
int mid = (l + r) >> 1;
return query(node << 1, l, mid, ql, qr) + query(node << 1 | 1, mid + 1, r, ql, qr);
}
} weight_st;
// HLD Logic
// First pass over the subtree rooted at `u`, entered from `p`. It records:
// - the parent of each node;
// - the weighted distance `dist` from the root;
// - the weight of each child's parent edge;
// - subtree sizes;
// - the heavy child (the first child with the largest subtree, or -1 for a leaf).
// `d` is the unweighted depth; it is passed down but never stored.
// Recursion depth equals the height of the tree.
void dfs_size(int u, int p, int d, ll dist) {
    parent[u] = p;
    depth_from_root[u] = dist;
    subtree_size[u] = 1;
    heavy_child[u] = -1;
    for (const auto& edge : adj[u]) {
        const int child = edge.first;
        if (child == p) continue;
        const int w = edge.second;
        edge_weight_to_parent[child] = w;
        dfs_size(child, u, d + 1, dist + w);
        subtree_size[u] += subtree_size[child];
        const int best = heavy_child[u];
        if (best == -1 || subtree_size[child] > subtree_size[best]) {
            heavy_child[u] = child;
        }
    }
}
// Second pass: hands out HLD positions. `h` is the head of the chain that `u`
// belongs to. The heavy child is visited first, so each heavy chain occupies
// consecutive positions. Every light child starts a new chain headed by itself.
void dfs_hld(int u, int h) {
    const int slot = timer++;
    path_head[u] = h;
    pos_in_tree[u] = slot;
    rev_pos[slot] = u;

    const int heavy = heavy_child[u];
    if (heavy != -1) dfs_hld(heavy, h);

    for (const auto& edge : adj[u]) {
        const int child = edge.first;
        if (child == parent[u] || child == heavy) continue;
        dfs_hld(child, child);
    }
}
// Adds `delta` to every node on the path from `u` up to the root, in both
// segment trees. The path is covered one heavy-chain segment
// [head .. current node] at a time.
void update_path(int u, int delta) {
    const int last = num_cities - 1;
    for (int v = u; v != -1; v = parent[path_head[v]]) {
        const int from = pos_in_tree[path_head[v]];
        const int to = pos_in_tree[v];
        sum_st.update(1, 0, last, from, to, delta);
        weight_st.update(1, 0, last, from, to, delta);
    }
}
int find_centroid_recursive(int u, ll total) {
for (auto& edge : adj[u]) {
int v = edge.first;
if (v != parent[u]) {
ll sub_sum = sum_st.query(1, 0, num_cities - 1, pos_in_tree[v], pos_in_tree[v]);
if (sub_sum * 2 > total) return find_centroid_recursive(v, total);
}
}
return u;
}
// Builds the tree from the N-1 edges (U[i], V[i]) with weight W[i] and roots
// it at city 0. It then runs both HLD passes, loads the parent-edge weights into
// weight_st, and seeds both segment trees with the initial delivery counts T.
// NOTE: adj and is_line_graph are never reset here, so this expects a single call.
void init(int N, vector<int> U, vector<int> V, vector<int> W, vector<int> T) {
    num_cities = N;
    timer = 0;

    // Undirected adjacency. Any edge other than (i, i+1) at index i disables
    // the line-graph fast path in max_time.
    for (int e = 0; e + 1 < N; ++e) {
        const int a = U[e];
        const int b = V[e];
        adj[a].push_back({b, W[e]});
        adj[b].push_back({a, W[e]});
        if (!(a == e && b == e + 1)) is_line_graph = false;
    }

    dfs_size(0, -1, 0, 0);
    dfs_hld(0, 0);

    // Leaf weights of weight_st, laid out in HLD order.
    vector<int> weights(N);
    for (int v = 0; v < N; ++v) weights[pos_in_tree[v]] = edge_weight_to_parent[v];
    weight_st.build(1, 0, N - 1, weights);

    // Record the initial counts and push the non-zero ones up their root paths.
    for (int v = 0; v < N; ++v) {
        delivery_counts[v] = T[v];
        if (T[v] > 0) update_path(v, T[v]);
    }
}
// Sets city `city_index`'s delivery count to `new_delivery_count` and returns
//   2*sum_v w_v*S_v + 2*depth(c)*(total+1) - 4*sum_{v on root..c} w_v*S_v,
// where:
// - S_v is v's subtree delivery sum;
// - w_v is the weight of v's parent edge;
// - total is the overall delivery sum;
// - c is the deepest node with 2*S_c > total.
// Because the parent-edge weights on root..c add up to depth(c), this equals
// 2*sum_v w_v*min(S_v, total+1-S_v) when all counts are non-negative. The +1
// presumably models the mandatory return to city 0; confirm against the spec.
ll max_time(int city_index, int new_delivery_count) {
    const int last = num_cities - 1;

    // Apply the change as a delta on every subtree that contains the city.
    const int delta = new_delivery_count - delivery_counts[city_index];
    delivery_counts[city_index] = new_delivery_count;
    update_path(city_index, delta);

    // The root sits at HLD position 0, so its subtree sum is the grand total.
    const ll total_T = sum_st.query(1, 0, last, 0, 0);

    int centroid = 0;
    if (!is_line_graph) {
        centroid = find_centroid_recursive(0, total_T);
    } else {
        // Path 0-1-...-N-1: position equals city, and subtree sums do not grow
        // with depth. Binary-search the last position with 2*S > total.
        int lo = 0;
        int hi = last;
        while (lo <= hi) {
            const int mid = (lo + hi) / 2;
            const ll s = sum_st.query(1, 0, last, mid, mid);
            if (s * 2 > total_T) {
                centroid = rev_pos[mid];
                lo = mid + 1;
            } else {
                hi = mid - 1;
            }
        }
    }

    ll result = 2LL * weight_st.query(1, 0, last, 0, last);
    result += 2LL * depth_from_root[centroid] * (1 + total_T);
    // Walk from the centroid to the root chain by chain, removing
    // 4 * sum of w_v*S_v along that path.
    for (int v = centroid; v != -1; v = parent[path_head[v]]) {
        result -= 4LL * weight_st.query(1, 0, last, pos_in_tree[path_head[v]], pos_in_tree[v]);
    }
    return result;
}