#include <algorithm>
#include <iostream>
#include <vector>
/*
 * Heavy-light decomposition of a tree rooted at vertex 0, used to run a tree DP that
 * chooses a set of pairwise vertex-disjoint paths with maximum total value (the
 * recurrence is explained in process_path()).
 *
 * Usage: add_edge() for every tree edge, decompose(), add_path() for every candidate
 * path, then get_answer(). Vertices are 0-indexed.
 *
 * No member function recurses, so deep (e.g. path-shaped) trees cannot overflow the
 * call stack.
 */
class heavy_light {
    struct path {
        int u, v;        // Endpoints of the path.
        int lca;         // Lowest common ancestor of u and v.
        long long value; // Score gained if this path is chosen.
    };
    // Selects which cumulative memo chain_sum() reads from.
    enum class chain_value_type {
        children, // Sum of the DP values of a vertex's children.
        vertex    // DP value of the vertex itself.
    };
    const int UNCHECKED = -1;
    const int ROOT = 0; // Standard root
    int num_vertices;
    std::vector<std::vector<int>> adj_list;
    std::vector<int> parent;              // Parent of each vertex in rooted tree (the root is its own parent).
    std::vector<int> depth;               // Depth of each vertex.
    std::vector<int> dfs_start;           // Time when a vertex's DFS processing starts. Used to check for ancestors.
    std::vector<int> dfs_end;             // Time when a vertex's DFS processing ends. Used to check for ancestors.
    std::vector<int> subtree_size;        // Number of vertices in subtree rooted at each vertex.
    std::vector<std::vector<int>> chains; // The i-th vector is the i-th chain, listed from its head downwards.
    std::vector<int> chain_index;         // Which chain is a vertex part of? (Starts from 0)
    std::vector<int> position_in_chain;   // How deep is a vertex in the chain? (0 is closest to the root)
    // DP Variables
    std::vector<std::vector<path>> lca_paths; // The i-th vector holds all paths whose LCA is i.
    // cumulative_chain_children_memo[c][j]: sum, over the vertices at positions j, j+1, ... of
    // chain c, of the sum of their children's DP values. Each vector ends with an extra 0
    // representing the empty suffix.
    std::vector<std::vector<long long>> cumulative_chain_children_memo;
    // Same suffix-sum layout as above, but summing each vertex's own DP value.
    std::vector<std::vector<long long>> cumulative_chain_vertex_memo;

    // DFS over the subtree rooted at u that fills parent, depth, subtree_size, dfs_start
    // and dfs_end. `time` advances by one before entering each child and by one after
    // finishing each vertex. Uses an explicit stack instead of recursion.
    void dfs(int u, int& time) {
        struct frame {
            int vertex;
            int next_edge; // Index into adj_list[vertex] of the next neighbour to examine.
        };
        std::vector<frame> frames;
        dfs_start[u] = time;
        subtree_size[u] = 1;
        frames.push_back({u, 0});
        while (!frames.empty()) {
            frame& top = frames.back();
            const int curr = top.vertex;
            if (top.next_edge < static_cast<int>(adj_list[curr].size())) {
                const int v = adj_list[curr][top.next_edge++];
                if (v == parent[curr]) {
                    continue;
                }
                parent[v] = curr;
                depth[v] = depth[curr] + 1;
                ++time;
                dfs_start[v] = time;
                subtree_size[v] = 1;
                frames.push_back({v, 0}); // May invalidate `top`; it is not used afterwards.
            } else {
                dfs_end[curr] = time;
                ++time;
                frames.pop_back(); // Invalidates `top`.
                if (!frames.empty()) {
                    subtree_size[frames.back().vertex] += subtree_size[curr];
                }
            }
        }
    }
    void make_rooted_tree() {
        parent[ROOT] = ROOT;
        depth[ROOT] = 0;
        int time = 0;
        dfs(ROOT, time);
    }
    // Creates a new chain whose first (topmost) vertex is `head`.
    void start_new_chain(int head) {
        const int new_chain_id = static_cast<int>(chains.size());
        chains.push_back({head});
        chain_index[head] = new_chain_id;
        position_in_chain[head] = 0;
    }
    // Returns the child of u with the largest subtree (the first such child in adjacency
    // order on ties), or UNCHECKED if u has no children.
    int heavy_child(int u) const {
        int heavy = UNCHECKED;
        for (int v : adj_list[u]) {
            if (v == parent[u]) {
                continue;
            }
            if (heavy == UNCHECKED || subtree_size[heavy] < subtree_size[v]) {
                heavy = v;
            }
        }
        return heavy;
    }
    // Splits the rooted tree into chains: each chain follows heavy children downwards,
    // and every light child starts a chain of its own.
    void construct_chains() {
        std::vector<int> pending_heads; // Light children whose chains are not built yet.
        pending_heads.push_back(ROOT);
        while (!pending_heads.empty()) {
            const int head = pending_heads.back();
            pending_heads.pop_back();
            start_new_chain(head);
            // Extend the chain through heavy children until a leaf is reached.
            int u = head;
            for (int heavy = heavy_child(u); heavy != UNCHECKED; heavy = heavy_child(u)) {
                // Push light children in reverse adjacency order so they pop in adjacency
                // order. Light children of deeper chain vertices are pushed later and are
                // therefore handled first. Chains get the same ids a recursive heavy-first
                // DFS would assign.
                for (auto it = adj_list[u].rbegin(); it != adj_list[u].rend(); ++it) {
                    if (*it != parent[u] && *it != heavy) {
                        pending_heads.push_back(*it);
                    }
                }
                chains[chain_index[u]].push_back(heavy);
                chain_index[heavy] = chain_index[u];
                position_in_chain[heavy] = position_in_chain[u] + 1;
                u = heavy;
            }
        }
    }
    // Returns true if u is an ancestor of v (a vertex counts as its own ancestor).
    bool is_ancestor(int u, int v) const {
        return dfs_start[u] <= dfs_start[v] && dfs_end[v] <= dfs_end[u];
    }
    // Returns the lowest common ancestor of u and v.
    int find_lca(int u, int v) const {
        // We don't really need HLD to find the LCA; it can be done in O(1) per query
        // using a Sparse Table. This version walks the chains instead.
        // Step 1: Jump up chain heads from u until the head is an ancestor of v; the LCA
        // then lies on that head's chain. At the latest this stops at the root's chain.
        int head = chains[chain_index[u]][0];
        while (!is_ancestor(head, v)) {
            head = chains[chain_index[parent[head]]][0];
        }
        // Step 2: The vertices of this chain that are ancestors of both u and v form a
        // prefix that contains the head; binary search for its deepest element.
        const std::vector<int>& chain = chains[chain_index[head]];
        int low = 0;
        int high = static_cast<int>(chain.size()) - 1;
        int lca = head;
        while (low <= high) {
            const int mid = low + (high - low) / 2;
            const int vertex = chain[mid];
            if (is_ancestor(vertex, v) && is_ancestor(vertex, u)) {
                lca = vertex;
                low = mid + 1;
            } else {
                high = mid - 1;
            }
        }
        return lca;
    }
    // Returns the sum of the selected per-vertex values over chain positions
    // [chain_pos_start, chain_pos_end] (inclusive) of chain `chain_id`, using the suffix
    // sums. Yields 0 when chain_pos_start == chain_pos_end + 1 (empty range).
    long long chain_sum(chain_value_type type, int chain_id, int chain_pos_start, int chain_pos_end) const {
        const std::vector<long long>& suffix = (type == chain_value_type::children
                                                ? cumulative_chain_children_memo[chain_id]
                                                : cumulative_chain_vertex_memo[chain_id]);
        return suffix[chain_pos_start] - suffix[chain_pos_end + 1];
    }
    // Returns the best total obtainable inside the LCA's subtree when path p is chosen.
    // It requires the memo entries of every vertex in that subtree, including the LCA's
    // children memo, to be final.
    long long process_path(const path& p) const {
        /*
        Here's the idea behind how this works.
        Consider the following tree, and the path from A1 to A5 (i.e. A1-A2-A3-A4-A5):

                          A3
                        /    \
                     A2        A4
                    /  \      /  \
                  A1    X3  X4    A5
                 /  \            / | \
               X1    X2        X5  X6  X7

        Let DP[V] represent the highest possible score (i.e. number of votes) obtainable
        among all the paths that exist in the subtree rooted at the vertex V.
        If we decide to take the path from A1 to A5, then the maximum remaining score we can
        obtain is the sum of the DP values of the remaining untouched children. In this case,
        that would be
            score(A1-A2-A3-A4-A5) + DP[X1] + DP[X2] + DP[X3] + DP[X4] + DP[X5] + DP[X6] + DP[X7]
        The DP terms can simply be added together because all the subtrees rooted at those
        vertices are disjoint.
        This is still hard to deal with, since there may be many children to consider, so
        let's represent it differently.
        Let children(V) be the set of children of a vertex. For example, children(A2)
        is {A1, X3}. Also for simplicity, let DP[{U1, ... , Uk}] = DP[U1] + ... + DP[Uk].
        Now the expression becomes:
            score(A1-A2-A3-A4-A5)
            + DP[children(A1)] + DP[children(A2)] + DP[children(A3)] + DP[children(A4)] + DP[children(A5)]
            - DP[A1] - DP[A2] - DP[A4] - DP[A5]
        More generally, for a path P = A-...-L-...-B, where L is the LCA of A and B:
            score(P) + sum(DP[children(v)] for v in P) - sum(DP[v] for v in P excluding L)
        Both sums over P are computed chain by chain with the suffix sums in chain_sum().
        */
        const int lca_chain_index = chain_index[p.lca];
        const int lca_pos_in_chain = position_in_chain[p.lca];
        long long answer = 0;
        const int endpoints[] = {p.u, p.v};
        for (int vertex : endpoints) {
            // Sum of (DP[children(x)] - DP[x]) over the path vertices x strictly below the
            // LCA on this endpoint's side.
            long long path_value = 0;
            // Step 1a: Climb whole chains until curr is on the same chain as the LCA.
            int curr = vertex;
            while (chain_index[curr] != lca_chain_index) {
                const int curr_chain_index = chain_index[curr];
                const int curr_pos_in_chain = position_in_chain[curr];
                path_value += chain_sum(chain_value_type::children, curr_chain_index, 0, curr_pos_in_chain);
                path_value -= chain_sum(chain_value_type::vertex, curr_chain_index, 0, curr_pos_in_chain);
                curr = parent[chains[curr_chain_index][0]]; // Step above the chain head.
            }
            // Step 1b: Climb from curr to just below the LCA (empty range if curr is the LCA).
            const int pos_on_lca_chain = position_in_chain[curr];
            path_value += chain_sum(chain_value_type::children, lca_chain_index, lca_pos_in_chain + 1, pos_on_lca_chain);
            path_value -= chain_sum(chain_value_type::vertex, lca_chain_index, lca_pos_in_chain + 1, pos_on_lca_chain);
            answer += path_value;
        }
        // The LCA contributes DP[children(L)] (but no -DP[L]); then add the path's own score.
        answer += chain_sum(chain_value_type::children, lca_chain_index, lca_pos_in_chain, lca_pos_in_chain) + p.value;
        return answer;
    }
    // Computes DP for every vertex in the subtree rooted at u, children before parents.
    // Fills both cumulative chain memos along the way and returns DP[u].
    long long bottom_up_dp(int u) {
        // Pre-order of the subtree. Walking it backwards visits every vertex after all of
        // its descendants, without recursion.
        std::vector<int> order;
        std::vector<int> to_visit;
        to_visit.push_back(u);
        while (!to_visit.empty()) {
            const int curr = to_visit.back();
            to_visit.pop_back();
            order.push_back(curr);
            for (int v : adj_list[curr]) {
                if (v != parent[curr]) {
                    to_visit.push_back(v);
                }
            }
        }
        // children_dp[x]: sum of the DP values of x's children, accumulated as they finish.
        std::vector<long long> children_dp(num_vertices, 0);
        long long u_dp_value = 0;
        for (auto it = order.rbegin(); it != order.rend(); ++it) {
            const int curr = *it;
            const int curr_chain = chain_index[curr];
            const int curr_pos = position_in_chain[curr];
            // The next chain position (the heavy child, or the trailing 0) is already final.
            cumulative_chain_children_memo[curr_chain][curr_pos] =
                cumulative_chain_children_memo[curr_chain][curr_pos + 1] + children_dp[curr];
            // DP[curr] is either the children's total (no chosen path has its LCA at curr)
            // or the best value of choosing one path whose LCA is curr.
            long long curr_dp_value = children_dp[curr];
            for (const path& p : lca_paths[curr]) {
                curr_dp_value = std::max(curr_dp_value, process_path(p));
            }
            cumulative_chain_vertex_memo[curr_chain][curr_pos] =
                cumulative_chain_vertex_memo[curr_chain][curr_pos + 1] + curr_dp_value;
            if (curr == u) {
                u_dp_value = curr_dp_value;
            } else {
                children_dp[parent[curr]] += curr_dp_value;
            }
        }
        return u_dp_value;
    }
public:
    // Creates an edgeless structure over vertices 0 .. _num_vertices - 1.
    heavy_light(int _num_vertices) :
        num_vertices(_num_vertices),
        adj_list(num_vertices),
        parent(num_vertices, UNCHECKED),
        depth(num_vertices, UNCHECKED),
        dfs_start(num_vertices, UNCHECKED),
        dfs_end(num_vertices, UNCHECKED),
        subtree_size(num_vertices, UNCHECKED),
        chain_index(num_vertices, UNCHECKED),
        position_in_chain(num_vertices, UNCHECKED),
        lca_paths(num_vertices) {}
    // Adds an undirected tree edge between u and v.
    void add_edge(int u, int v) {
        adj_list[u].push_back(v);
        adj_list[v].push_back(u);
    }
    // Roots the tree at ROOT and builds the heavy-light chains. Call after all edges are added.
    void decompose() {
        make_rooted_tree();
        construct_chains();
    }
    // Debug dump of the chains and per-vertex DFS/chain data to stdout.
    void print_details() const {
        for (int i = 0; i < static_cast<int>(chains.size()); ++i) {
            std::cout << "Chain #" << i << ": ";
            for (int u : chains[i]) {
                std::cout << u << "-";
            }
            std::cout << "\n";
        }
        std::cout << "\n";
        for (int i = 0; i < num_vertices; ++i) {
            std::cout << "Vertex #" << i
                      << " - DFS times: " << dfs_start[i] << " to " << dfs_end[i]
                      << ", Chain ID: " << chain_index[i]
                      << ", Position in Chain: " << position_in_chain[i] << "\n";
        }
    }
    // Records a candidate path between u and v worth `value`. The LCA is computed
    // immediately, so decompose() must have been called first.
    void add_path(int u, int v, long long value) {
        const int lca = find_lca(u, v);
        lca_paths[lca].push_back({u, v, lca, value});
    }
    // Returns the maximum total value of a set of pairwise vertex-disjoint paths among
    // those added with add_path(). The memo tables are rebuilt on every call, so
    // repeated calls do not grow memory.
    long long get_answer() {
        cumulative_chain_children_memo.clear();
        cumulative_chain_vertex_memo.clear();
        for (const auto& chain : chains) {
            // One slot per chain vertex plus a trailing 0 for the empty suffix.
            cumulative_chain_children_memo.emplace_back(chain.size() + 1, 0);
            cumulative_chain_vertex_memo.emplace_back(chain.size() + 1, 0);
        }
        return bottom_up_dp(ROOT);
    }
};
int main() {
std::ios_base::sync_with_stdio(false);
std::cin.tie(nullptr);
int num_vertices;
std::cin >> num_vertices;
heavy_light hld(num_vertices);
for (int i = 0; i < num_vertices - 1; ++i) {
int u, v;
std::cin >> u >> v;
hld.add_edge(u - 1, v - 1); // 1-indexed vertices
}
hld.decompose();
int num_paths;
std::cin >> num_paths;
for (int i = 0; i < num_paths; ++i) {
int u, v;
long long value;
std::cin >> u >> v >> value;
hld.add_path(u - 1, v - 1, value); // 1-indexed vertices
}
std::cout << hld.get_answer() << "\n";
}
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Correct |
2 ms |
376 KB |
Output is correct |
2 |
Correct |
2 ms |
252 KB |
Output is correct |
3 |
Correct |
2 ms |
376 KB |
Output is correct |
4 |
Correct |
3 ms |
632 KB |
Output is correct |
5 |
Correct |
131 ms |
20140 KB |
Output is correct |
6 |
Correct |
62 ms |
21876 KB |
Output is correct |
7 |
Correct |
115 ms |
22316 KB |
Output is correct |
8 |
Correct |
121 ms |
28880 KB |
Output is correct |
9 |
Correct |
110 ms |
20544 KB |
Output is correct |
10 |
Correct |
127 ms |
28708 KB |
Output is correct |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Correct |
2 ms |
256 KB |
Output is correct |
2 |
Correct |
2 ms |
376 KB |
Output is correct |
3 |
Correct |
3 ms |
504 KB |
Output is correct |
4 |
Correct |
136 ms |
25300 KB |
Output is correct |
5 |
Correct |
138 ms |
25456 KB |
Output is correct |
6 |
Correct |
137 ms |
25328 KB |
Output is correct |
7 |
Correct |
136 ms |
25328 KB |
Output is correct |
8 |
Correct |
135 ms |
25300 KB |
Output is correct |
9 |
Correct |
134 ms |
25328 KB |
Output is correct |
10 |
Correct |
152 ms |
25328 KB |
Output is correct |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Correct |
2 ms |
256 KB |
Output is correct |
2 |
Correct |
2 ms |
376 KB |
Output is correct |
3 |
Correct |
3 ms |
504 KB |
Output is correct |
4 |
Correct |
136 ms |
25300 KB |
Output is correct |
5 |
Correct |
138 ms |
25456 KB |
Output is correct |
6 |
Correct |
137 ms |
25328 KB |
Output is correct |
7 |
Correct |
136 ms |
25328 KB |
Output is correct |
8 |
Correct |
135 ms |
25300 KB |
Output is correct |
9 |
Correct |
134 ms |
25328 KB |
Output is correct |
10 |
Correct |
152 ms |
25328 KB |
Output is correct |
11 |
Correct |
14 ms |
1400 KB |
Output is correct |
12 |
Correct |
144 ms |
25328 KB |
Output is correct |
13 |
Correct |
143 ms |
25392 KB |
Output is correct |
14 |
Correct |
141 ms |
25368 KB |
Output is correct |
15 |
Correct |
143 ms |
25340 KB |
Output is correct |
16 |
Correct |
142 ms |
25300 KB |
Output is correct |
17 |
Correct |
143 ms |
25324 KB |
Output is correct |
18 |
Correct |
138 ms |
25328 KB |
Output is correct |
19 |
Correct |
140 ms |
25328 KB |
Output is correct |
20 |
Correct |
139 ms |
25328 KB |
Output is correct |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Correct |
274 ms |
22976 KB |
Output is correct |
2 |
Correct |
135 ms |
25332 KB |
Output is correct |
3 |
Correct |
218 ms |
25292 KB |
Output is correct |
4 |
Correct |
207 ms |
31372 KB |
Output is correct |
5 |
Correct |
216 ms |
24620 KB |
Output is correct |
6 |
Correct |
207 ms |
31528 KB |
Output is correct |
7 |
Correct |
217 ms |
24364 KB |
Output is correct |
8 |
Correct |
286 ms |
23256 KB |
Output is correct |
9 |
Correct |
154 ms |
25328 KB |
Output is correct |
10 |
Correct |
215 ms |
23240 KB |
Output is correct |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Correct |
2 ms |
376 KB |
Output is correct |
2 |
Correct |
2 ms |
252 KB |
Output is correct |
3 |
Correct |
2 ms |
376 KB |
Output is correct |
4 |
Correct |
3 ms |
632 KB |
Output is correct |
5 |
Correct |
131 ms |
20140 KB |
Output is correct |
6 |
Correct |
62 ms |
21876 KB |
Output is correct |
7 |
Correct |
115 ms |
22316 KB |
Output is correct |
8 |
Correct |
121 ms |
28880 KB |
Output is correct |
9 |
Correct |
110 ms |
20544 KB |
Output is correct |
10 |
Correct |
127 ms |
28708 KB |
Output is correct |
11 |
Correct |
3 ms |
632 KB |
Output is correct |
12 |
Correct |
3 ms |
632 KB |
Output is correct |
13 |
Correct |
3 ms |
636 KB |
Output is correct |
14 |
Correct |
3 ms |
632 KB |
Output is correct |
15 |
Correct |
3 ms |
632 KB |
Output is correct |
16 |
Correct |
3 ms |
632 KB |
Output is correct |
17 |
Correct |
3 ms |
632 KB |
Output is correct |
18 |
Correct |
3 ms |
632 KB |
Output is correct |
19 |
Correct |
3 ms |
632 KB |
Output is correct |
20 |
Correct |
3 ms |
504 KB |
Output is correct |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Correct |
2 ms |
376 KB |
Output is correct |
2 |
Correct |
2 ms |
252 KB |
Output is correct |
3 |
Correct |
2 ms |
376 KB |
Output is correct |
4 |
Correct |
3 ms |
632 KB |
Output is correct |
5 |
Correct |
131 ms |
20140 KB |
Output is correct |
6 |
Correct |
62 ms |
21876 KB |
Output is correct |
7 |
Correct |
115 ms |
22316 KB |
Output is correct |
8 |
Correct |
121 ms |
28880 KB |
Output is correct |
9 |
Correct |
110 ms |
20544 KB |
Output is correct |
10 |
Correct |
127 ms |
28708 KB |
Output is correct |
11 |
Correct |
2 ms |
256 KB |
Output is correct |
12 |
Correct |
2 ms |
376 KB |
Output is correct |
13 |
Correct |
3 ms |
504 KB |
Output is correct |
14 |
Correct |
136 ms |
25300 KB |
Output is correct |
15 |
Correct |
138 ms |
25456 KB |
Output is correct |
16 |
Correct |
137 ms |
25328 KB |
Output is correct |
17 |
Correct |
136 ms |
25328 KB |
Output is correct |
18 |
Correct |
135 ms |
25300 KB |
Output is correct |
19 |
Correct |
134 ms |
25328 KB |
Output is correct |
20 |
Correct |
152 ms |
25328 KB |
Output is correct |
21 |
Correct |
14 ms |
1400 KB |
Output is correct |
22 |
Correct |
144 ms |
25328 KB |
Output is correct |
23 |
Correct |
143 ms |
25392 KB |
Output is correct |
24 |
Correct |
141 ms |
25368 KB |
Output is correct |
25 |
Correct |
143 ms |
25340 KB |
Output is correct |
26 |
Correct |
142 ms |
25300 KB |
Output is correct |
27 |
Correct |
143 ms |
25324 KB |
Output is correct |
28 |
Correct |
138 ms |
25328 KB |
Output is correct |
29 |
Correct |
140 ms |
25328 KB |
Output is correct |
30 |
Correct |
139 ms |
25328 KB |
Output is correct |
31 |
Correct |
274 ms |
22976 KB |
Output is correct |
32 |
Correct |
135 ms |
25332 KB |
Output is correct |
33 |
Correct |
218 ms |
25292 KB |
Output is correct |
34 |
Correct |
207 ms |
31372 KB |
Output is correct |
35 |
Correct |
216 ms |
24620 KB |
Output is correct |
36 |
Correct |
207 ms |
31528 KB |
Output is correct |
37 |
Correct |
217 ms |
24364 KB |
Output is correct |
38 |
Correct |
286 ms |
23256 KB |
Output is correct |
39 |
Correct |
154 ms |
25328 KB |
Output is correct |
40 |
Correct |
215 ms |
23240 KB |
Output is correct |
41 |
Correct |
3 ms |
632 KB |
Output is correct |
42 |
Correct |
3 ms |
632 KB |
Output is correct |
43 |
Correct |
3 ms |
636 KB |
Output is correct |
44 |
Correct |
3 ms |
632 KB |
Output is correct |
45 |
Correct |
3 ms |
632 KB |
Output is correct |
46 |
Correct |
3 ms |
632 KB |
Output is correct |
47 |
Correct |
3 ms |
632 KB |
Output is correct |
48 |
Correct |
3 ms |
632 KB |
Output is correct |
49 |
Correct |
3 ms |
632 KB |
Output is correct |
50 |
Correct |
3 ms |
504 KB |
Output is correct |
51 |
Correct |
270 ms |
23328 KB |
Output is correct |
52 |
Correct |
141 ms |
25328 KB |
Output is correct |
53 |
Correct |
218 ms |
23228 KB |
Output is correct |
54 |
Correct |
198 ms |
31400 KB |
Output is correct |
55 |
Correct |
280 ms |
23052 KB |
Output is correct |
56 |
Correct |
142 ms |
25412 KB |
Output is correct |
57 |
Correct |
216 ms |
24024 KB |
Output is correct |
58 |
Correct |
195 ms |
31412 KB |
Output is correct |
59 |
Correct |
328 ms |
23212 KB |
Output is correct |
60 |
Correct |
154 ms |
25328 KB |
Output is correct |
61 |
Correct |
207 ms |
24108 KB |
Output is correct |
62 |
Correct |
203 ms |
31284 KB |
Output is correct |
63 |
Correct |
330 ms |
22828 KB |
Output is correct |
64 |
Correct |
144 ms |
25328 KB |
Output is correct |
65 |
Correct |
211 ms |
24028 KB |
Output is correct |
66 |
Correct |
205 ms |
31396 KB |
Output is correct |
67 |
Correct |
292 ms |
22828 KB |
Output is correct |
68 |
Correct |
141 ms |
25328 KB |
Output is correct |
69 |
Correct |
222 ms |
22836 KB |
Output is correct |
70 |
Correct |
193 ms |
31396 KB |
Output is correct |