Submission #130810

# | Time | Username | Problem | Language | Result | Execution time | Memory
130810 | xantho | Election Campaign (JOI15_election_campaign) | C++17
100 / 100
321 ms | 34272 KiB
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <utility>
#include <vector>

// Heavy-light decomposition of a tree rooted at vertex 0, combined with a tree
// DP over weighted paths (JOI 2015 "Election Campaign").
//
// get_answer() returns the maximum total value of a set of added paths in which
// no two paths share a vertex. The DP value of u is one of two options:
//   - the sum of u's children's DP values (no chosen path uses u), or
//   - the best added path whose LCA is u, plus everything hanging off that path.
//
// Call order: add_edge() for every edge, decompose(), add_path() for every
// candidate path (it needs the chains for LCA queries), then get_answer().
// All traversals are iterative, so very deep (path-shaped) trees do not depend
// on the size of the call stack.
class heavy_light {
    struct path {
        int u, v;
        int lca;
        long long value;
    };

    // constexpr static data members are implicitly inline in C++17. Binding them
    // to a const reference (std::vector's fill constructor does this) therefore
    // needs no out-of-class definition. A plain `static const int` that is
    // ODR-used this way can fail to link, e.g. in unoptimised builds.
    static constexpr int UNCHECKED = -1;
    static constexpr int ROOT = 0; // Standard root

    int num_vertices;
    std::vector<std::vector<int>> adj_list;
    std::vector<int> parent; // Parent of each vertex in the rooted tree (the root is its own parent).
    std::vector<int> depth; // Depth of each vertex (the root has depth 0).
    std::vector<int> dfs_start; // Time when a vertex's DFS processing starts. Used to check for ancestors.
    std::vector<int> dfs_end; // Time when a vertex's DFS processing ends. Used to check for ancestors.
    std::vector<int> subtree_size; // Number of vertices in subtree rooted at each vertex.
    std::vector<std::vector<int>> chains; // The i-th vector is the i-th chain, listed from its head downwards.
    std::vector<int> chain_index; // Which chain is a vertex part of? (Starts from 0)
    std::vector<int> position_in_chain; // How deep is a vertex in the chain? (0 is closest to the root)
    std::vector<int> preorder; // Vertices reachable from ROOT in DFS preorder (every parent precedes its children).

    // DP Variables
    std::vector<std::vector<path>> lca_paths; // The i-th vector represents all paths that have i as the LCA
    std::vector<std::vector<long long>> cumulative_chain_children_memo; // The j-th index of the i-th vector is the sum of the children's DP values for the j-th vertex and beyond within the i-th chain (one trailing 0 sentinel)
    std::vector<std::vector<long long>> cumulative_chain_vertex_memo; // Exactly the same as above, but stores the vertex's DP value instead of the children's

    // Roots the tree at ROOT and fills parent, depth, subtree_size, preorder and
    // the dfs_start/dfs_end timestamps.
    // The clock advances once before entering each child and once after a vertex
    // is finished, so every subtree occupies a nested interval
    // [dfs_start, dfs_end]. Children are visited in adjacency order using an
    // explicit stack instead of recursion.
    void make_rooted_tree() {
        parent[ROOT] = ROOT;
        depth[ROOT] = 0;

        preorder.clear();
        preorder.reserve(static_cast<std::size_t>(num_vertices));

        int time = 0;
        // Each frame holds a vertex and the index of its next adjacency entry to inspect.
        std::vector<std::pair<int, std::size_t>> stack;

        auto enter = [&](int u) {
            dfs_start[u] = time;
            subtree_size[u] = 1;
            preorder.push_back(u);
            stack.emplace_back(u, 0);
        };

        enter(ROOT);
        while (!stack.empty()) {
            auto& frame = stack.back();
            const int u = frame.first;

            if (frame.second == adj_list[u].size()) {
                // All children are done: close u's interval and fold its size into its parent.
                dfs_end[u] = time;
                ++time;
                stack.pop_back();
                if (!stack.empty()) {
                    subtree_size[stack.back().first] += subtree_size[u];
                }
                continue;
            }

            const int v = adj_list[u][frame.second++];
            if (v == parent[u]) {
                continue;
            }

            parent[v] = u;
            depth[v] = depth[u] + 1;

            ++time;
            enter(v); // May reallocate `stack`; `frame` is not used after this point.
        }
    }

    // Opens chain number chains.size() with `head` as its first (topmost) vertex.
    void start_new_chain(int head) {
        int new_chain_id = static_cast<int>(chains.size());

        chains.push_back({head});
        chain_index[head] = new_chain_id;
        position_in_chain[head] = 0;
    }

    // Returns the child of u with the largest subtree (the first one in
    // adjacency order on ties), or UNCHECKED if u has no children.
    int heavy_child(int u) const {
        int best = UNCHECKED;
        for (int v : adj_list[u]) {
            if (v == parent[u]) {
                continue;
            }
            if (best == UNCHECKED || subtree_size[best] < subtree_size[v]) {
                best = v;
            }
        }
        return best;
    }

    // Splits the rooted tree into heavy chains.
    // Vertices are visited in the same order as a recursive walk that handles
    // the heavy child's whole subtree first and then each light child in
    // adjacency order. Chain ids are therefore assigned in that order too. Every
    // light child heads a new chain.
    void construct_chains() {
        chains.clear();

        // Each entry is a vertex to visit and whether it heads a new chain.
        std::vector<std::pair<int, bool>> stack;
        stack.emplace_back(ROOT, true);

        while (!stack.empty()) {
            const auto [u, is_head] = stack.back();
            stack.pop_back();

            if (is_head) {
                start_new_chain(u);
            }

            const int heavy = heavy_child(u);
            // Case: No children
            if (heavy == UNCHECKED) {
                continue;
            }

            // The heavy child continues u's chain one position deeper.
            chains[chain_index[u]].push_back(heavy);
            chain_index[heavy] = chain_index[u];
            position_in_chain[heavy] = position_in_chain[u] + 1;

            // Push the light children in reverse so they pop in adjacency order.
            // Push the heavy child last so its whole subtree is handled first.
            for (auto it = adj_list[u].rbegin(); it != adj_list[u].rend(); ++it) {
                const int v = *it;
                if (v == parent[u] || v == heavy) {
                    continue;
                }
                stack.emplace_back(v, true);
            }
            stack.emplace_back(heavy, false);
        }
    }

    // Returns true if u is an ancestor of v (every vertex is its own ancestor).
    bool is_ancestor(int u, int v) const {
        return dfs_start[u] <= dfs_start[v] && dfs_end[v] <= dfs_end[u];
    }

    // Lowest common ancestor of u and v.
    // Climbs u's chain heads until one is an ancestor of v; that chain holds the
    // LCA. Along that chain, "ancestor of both u and v" holds for a prefix of
    // positions, so a binary search finds the deepest such vertex.
    int find_lca(int u, int v) const {
        // Step 1: Find chain containing LCA
        int u_chain_index = chain_index[u];
        int u_curr = chains[u_chain_index][0];

        while (!is_ancestor(u_curr, v)) {
            u_curr = parent[u_curr]; // Move to parent chain
            u_chain_index = chain_index[u_curr];
            u_curr = chains[u_chain_index][0]; // Move to head of chain
        }

        int lca_chain_index = chain_index[u_curr];

        int low = 0;
        int high = static_cast<int>(chains[lca_chain_index].size()) - 1;
        int lca = u_curr;

        // Step 2: Perform Binary Search on the chain
        while (low <= high) {
            int mid = low + (high - low) / 2;
            int vertex = chains[lca_chain_index][mid];

            if (is_ancestor(vertex, v) && is_ancestor(vertex, u)) {
                lca = vertex;
                low = mid + 1;
            } else {
                high = mid - 1;
            }
        }

        return lca;
    }

    // Value of taking path p in the DP at its LCA:
    //   p.value + children(lca) + sum over the other path vertices w of (children(w) - dp(w)),
    // where children(w) is the sum of w's children's DP values.
    // Each chain segment on the way up is summed in O(1) from the suffix tables.
    // The LCA's children entry must already be stored, and every path vertex
    // below the LCA must already be processed.
    long long process_path(const path& p) const {
        const int lca_chain_index = chain_index[p.lca];
        const int lca_pos_in_chain = position_in_chain[p.lca];

        long long answer = 0;

        const int endpoints[] = {p.u, p.v};
        for (const int vertex : endpoints) {
            // Step 1: Climb up from the endpoint to the LCA
            long long path_value = 0;

            // Step 1a: While outside the LCA's chain, add the segment from the
            // current chain's head down to curr (positions 0..curr_pos), then
            // jump to the parent of that head.
            int curr = vertex;
            while (chain_index[curr] != lca_chain_index) {
                const int curr_chain_index = chain_index[curr];
                const int curr_pos_in_chain = position_in_chain[curr];
                const int curr_chain_head = chains[curr_chain_index][0];

                path_value += cumulative_chain_children_memo[curr_chain_index][0] - cumulative_chain_children_memo[curr_chain_index][curr_pos_in_chain + 1];
                path_value -= cumulative_chain_vertex_memo[curr_chain_index][0] - cumulative_chain_vertex_memo[curr_chain_index][curr_pos_in_chain + 1];

                curr = parent[curr_chain_head];
            }

            // Step 1b: On the LCA's chain, add positions lca_pos + 1 .. curr_pos
            // (an empty range when curr is the LCA itself).
            {
                const int curr_pos_in_chain = position_in_chain[curr];
                path_value += cumulative_chain_children_memo[lca_chain_index][lca_pos_in_chain + 1] - cumulative_chain_children_memo[lca_chain_index][curr_pos_in_chain + 1];
                path_value -= cumulative_chain_vertex_memo[lca_chain_index][lca_pos_in_chain + 1] - cumulative_chain_vertex_memo[lca_chain_index][curr_pos_in_chain + 1];
            }

            answer += path_value;
        }

        // The LCA itself contributes only its children's DP sum.
        answer += (cumulative_chain_children_memo[lca_chain_index][lca_pos_in_chain] - cumulative_chain_children_memo[lca_chain_index][lca_pos_in_chain + 1]) + p.value;
        return answer;
    }

    // Computes the DP value of u (the best total obtainable within u's subtree)
    // from the sum of its children's DP values. It records both numbers in
    // u's chain suffix tables. Every descendant of u must already be processed.
    long long compute_vertex_dp(int u, long long children_dp) {
        const int u_chain_index = chain_index[u];
        const int u_pos_in_chain = position_in_chain[u];

        // Entry pos + 1 belongs to u's heavy child (or is the trailing 0 sentinel).
        cumulative_chain_children_memo[u_chain_index][u_pos_in_chain] = cumulative_chain_children_memo[u_chain_index][u_pos_in_chain + 1] + children_dp;

        // Either no chosen path uses u, or the best path whose LCA is u.
        long long u_dp_value = children_dp;
        for (const path& p : lca_paths[u]) {
            u_dp_value = std::max(u_dp_value, process_path(p));
        }

        cumulative_chain_vertex_memo[u_chain_index][u_pos_in_chain] = cumulative_chain_vertex_memo[u_chain_index][u_pos_in_chain + 1] + u_dp_value;
        return u_dp_value;
    }

public:
    heavy_light(int _num_vertices) :
            num_vertices(_num_vertices),
            adj_list(num_vertices),
            parent(num_vertices, UNCHECKED),
            depth(num_vertices, UNCHECKED),
            dfs_start(num_vertices, UNCHECKED),
            dfs_end(num_vertices, UNCHECKED),
            subtree_size(num_vertices, UNCHECKED),
            chain_index(num_vertices, UNCHECKED),
            position_in_chain(num_vertices, UNCHECKED),
            lca_paths(num_vertices) {}

    // Adds an undirected edge between 0-indexed vertices u and v.
    void add_edge(int u, int v) {
        adj_list[u].push_back(v);
        adj_list[v].push_back(u);
    }

    // Roots the tree at vertex 0 and builds the heavy chains. Call after all
    // edges are added and before add_path(). An empty tree is left untouched.
    void decompose() {
        if (num_vertices == 0) {
            return;
        }
        make_rooted_tree();
        construct_chains();
    }

    // Debug dump of the chains and per-vertex DFS/chain data to std::cout.
    void print_details() const {
        for (int i = 0; i < static_cast<int>(chains.size()); ++i) {
            std::cout << "Chain #" << i << ": ";
            for (int u : chains[i]) {
                std::cout << u << "-";
            }
            std::cout << "\n";
        }
        std::cout << "\n";

        for (int i = 0; i < num_vertices; ++i) {
            std::cout << "Vertex #" << i
                    << " - DFS times: " << dfs_start[i] << " to " << dfs_end[i]
                    << ", Chain ID: " << chain_index[i]
                    << ", Position in Chain: " << position_in_chain[i] << "\n";
        }
    }

    // Registers a candidate path between 0-indexed vertices u and v worth `value`.
    // Requires decompose() to have been called (the LCA query uses the chains).
    void add_path(int u, int v, long long value) {
        int lca = find_lca(u, v);
        lca_paths[lca].push_back({u, v, lca, value});
    }

    // Returns the maximum total value of a vertex-disjoint subset of the added
    // paths (0 for an empty tree or when no path is added). It can be called
    // repeatedly; the DP tables are rebuilt each time.
    long long get_answer() {
        cumulative_chain_children_memo.clear();
        cumulative_chain_vertex_memo.clear();
        for (const auto& chain : chains) {
            cumulative_chain_children_memo.emplace_back(chain.size() + 1, 0LL);
            cumulative_chain_vertex_memo.emplace_back(chain.size() + 1, 0LL);
        }

        // Reverse preorder handles every vertex after all of its descendants.
        std::vector<long long> children_dp(static_cast<std::size_t>(num_vertices), 0LL);
        long long root_dp = 0;
        for (auto it = preorder.rbegin(); it != preorder.rend(); ++it) {
            const int u = *it;
            const long long u_dp = compute_vertex_dp(u, children_dp[u]);
            if (u == ROOT) {
                root_dp = u_dp;
            } else {
                children_dp[parent[u]] += u_dp;
            }
        }
        return root_dp;
    }
};

int main() {
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int num_vertices;
    std::cin >> num_vertices;

    heavy_light hld(num_vertices);

    for (int i = 0; i < num_vertices - 1; ++i) {
        int u, v;
        std::cin >> u >> v;

        hld.add_edge(u - 1, v - 1); // 1-indexed vertices
    }

    hld.decompose();

    int num_paths;
    std::cin >> num_paths;

    for (int i = 0; i < num_paths; ++i) {
        int u, v;
        long long value;

        std::cin >> u >> v >> value;

        hld.add_path(u - 1, v - 1, value); // 1-indexed vertices
    }

    std::cout << hld.get_answer() << "\n";
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...