제출 #1366938

# | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리
1366938 | thanhbinh13 | Harmonija (COCI25_harmonija) | C++20
110 / 110
497 ms | 113644 KiB
#include <bits/stdc++.h>
using namespace std;

// Sentinel meaning "no valid transition" in the (max, +) semiring.
// multiply_mat treats any entry at or below NEG / 2 as unreachable.
const long long NEG = -(1LL << 60);

// Square 5x5 matrix over the (max, +) semiring: a[i][j] is the best score
// for moving from state i to state j, or NEG when that move is impossible.
struct Mat {
    long long a[5][5];
};

// Returns a matrix whose every entry is NEG (the semiring's "zero").
Mat make_neg() {
    Mat out;
    for (auto& row : out.a) {
        for (long long& cell : row) {
            cell = NEG;
        }
    }
    return out;
}

// Returns the (max, +) identity matrix: 0 on the diagonal, NEG elsewhere.
Mat make_identity() {
    Mat id = make_neg();
    for (int d = 0; d < 5; ++d) {
        id.a[d][d] = 0;
    }
    return id;
}

// (max, +) matrix product: result[i][k] = max over j of A[i][j] + B[j][k].
// Terms that involve an unreachable entry (<= NEG / 2) are skipped, so a
// result entry with no usable intermediate state j stays exactly NEG.
Mat multiply_mat(const Mat& A, const Mat& B) {
    Mat prod = make_neg();
    for (int from = 0; from < 5; ++from) {
        for (int mid = 0; mid < 5; ++mid) {
            const long long lhs = A.a[from][mid];
            if (lhs <= NEG / 2) {
                continue;  // from -> mid is impossible
            }
            for (int to = 0; to < 5; ++to) {
                const long long rhs = B.a[mid][to];
                if (rhs > NEG / 2) {
                    long long& best = prod.a[from][to];
                    best = std::max(best, lhs + rhs);
                }
            }
        }
    }
    return prod;
}

struct SegNode {
    Mat fwd, rev;
};

SegNode identity_node() {
    SegNode res;
    res.fwd = make_identity();
    res.rev = make_identity();
    return res;
}

SegNode merge_node(const SegNode& L, const SegNode& R) {
    SegNode res;

    res.fwd = multiply_mat(L.fwd, R.fwd);
    res.rev = multiply_mat(R.rev, L.rev);

    return res;
}

// Problem input, filled in by main(): vertex count n, query count q,
// adjacency lists g, and the per-vertex values c and p (all 1-indexed).
int n;
int q;
vector<vector<int>> g;
vector<long long> c;
vector<long long> p;

// Heavy-light decomposition tables, filled in by build_hld().
vector<int> parent_;
vector<int> depth_;
vector<int> sz;
vector<int> heavy;
vector<int> head;
vector<int> pos;
vector<int> node_at_pos;

// Builds the transition matrix for visiting vertex u. From state s the walk
// may step to s + 1, collecting c[u], or to s - 1, collecting p[u]. Steps
// that would leave the state range [0, 4] are forbidden (left at NEG).
// NOTE(review): the problem-level meaning of the states (presumably a
// running balance centred on state 2, where main starts) is inferred.
Mat make_node_matrix(int u) {
    Mat trans = make_neg();
    for (int s = 0; s < 5; ++s) {
        if (s < 4) {
            trans.a[s][s + 1] = c[u];  // step up
        }
        if (s > 0) {
            trans.a[s][s - 1] = p[u];  // step down
        }
    }
    return trans;
}

struct SegmentTree {
    int size_;
    vector<SegNode> tree;

    SegmentTree() {
        size_ = 0;
    }

    void init(int n) {
        size_ = 1;
        while (size_ < n) size_ <<= 1;
        tree.assign(2 * size_, identity_node());
    }

    void set_leaf(int index, const Mat& m) {
        int id = size_ + index - 1;

        tree[id].fwd = m;
        tree[id].rev = m;
    }

    void build() {
        for (int i = size_ - 1; i >= 1; i--) {
            tree[i] = merge_node(tree[i << 1], tree[i << 1 | 1]);
        }
    }

    SegNode query(int l, int r) {
        int left = size_ + l - 1;
        int right = size_ + r;

        SegNode res_left = identity_node();
        SegNode res_right = identity_node();

        while (left < right) {
            if (left & 1) {
                res_left = merge_node(res_left, tree[left]);
                left++;
            }

            if (right & 1) {
                right--;
                res_right = merge_node(tree[right], res_right);
            }

            left >>= 1;
            right >>= 1;
        }

        return merge_node(res_left, res_right);
    }
};

SegmentTree seg;

// Heavy-light decomposition of the tree rooted at vertex 1 (assumes g holds
// a tree on vertices 1..n — TODO confirm input guarantees). Fills:
//   parent_, depth_  — rooted-tree structure (parent_[1] = 0),
//   sz               — subtree sizes,
//   heavy            — child with the largest subtree, -1 for leaves,
//   head             — top vertex of the heavy chain containing each vertex,
//   pos/node_at_pos  — 1-based numbering in which every heavy chain occupies
//                      consecutive positions, top vertex first.
// Explicit stacks are used instead of recursion.
void build_hld() {
    parent_.assign(n + 1, 0);
    depth_.assign(n + 1, 0);
    sz.assign(n + 1, 1);
    heavy.assign(n + 1, -1);
    head.assign(n + 1, 0);
    pos.assign(n + 1, 0);
    node_at_pos.assign(n + 1, 0);

    // Pass 1: DFS from the root recording parents, depths, and a visit order
    // in which every vertex appears after its parent.
    vector<int> order;
    order.reserve(n);
    vector<int> todo;
    todo.push_back(1);
    parent_[1] = 0;
    depth_[1] = 0;
    while (!todo.empty()) {
        const int u = todo.back();
        todo.pop_back();
        order.push_back(u);
        for (int v : g[u]) {
            if (v == parent_[u]) {
                continue;
            }
            parent_[v] = u;
            depth_[v] = depth_[u] + 1;
            todo.push_back(v);
        }
    }

    // Pass 2: walk the order backwards (children before parents) to get
    // subtree sizes and pick each vertex's heaviest child.
    for (int idx = n - 1; idx >= 0; --idx) {
        const int u = order[idx];
        sz[u] = 1;
        int best = 0;
        for (int v : g[u]) {
            if (parent_[v] != u) {
                continue;
            }
            sz[u] += sz[v];
            if (sz[v] > best) {
                best = sz[v];
                heavy[u] = v;
            }
        }
    }

    // Pass 3: number the vertices chain by chain. Each stack entry is the
    // top of a chain; light children start new chains headed by themselves.
    int timer = 0;
    vector<int> chain_tops;
    chain_tops.push_back(1);
    while (!chain_tops.empty()) {
        const int top = chain_tops.back();
        chain_tops.pop_back();
        for (int x = top; x != -1; x = heavy[x]) {
            head[x] = top;
            pos[x] = ++timer;
            node_at_pos[timer] = x;
            for (int v : g[x]) {
                if (parent_[v] == x && v != heavy[x]) {
                    chain_tops.push_back(v);
                }
            }
        }
    }
}

// Returns the ordered (max, +) product of the vertex matrices along the tree
// path u -> v, both endpoints included, multiplied in walk order from u.
// `up` collects the part climbing from u (chain segments read bottom-to-top
// via .rev); `down` collects the part descending to v (chain segments read
// top-to-bottom via .fwd, prepended because higher chains come first).
Mat query_path(int u, int v) {
    Mat up = make_identity();
    Mat down = make_identity();

    while (head[u] != head[v]) {
        if (depth_[head[u]] < depth_[head[v]]) {
            down = multiply_mat(seg.query(pos[head[v]], pos[v]).fwd, down);
            v = parent_[head[v]];
        } else {
            up = multiply_mat(up, seg.query(pos[head[u]], pos[u]).rev);
            u = parent_[head[u]];
        }
    }

    // u and v now share a chain; the shallower one is the LCA.
    const bool u_below = depth_[u] >= depth_[v];
    const SegNode tail = u_below ? seg.query(pos[v], pos[u])
                                 : seg.query(pos[u], pos[v]);
    up = multiply_mat(up, u_below ? tail.rev : tail.fwd);

    return multiply_mat(up, down);
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    cin >> n >> q;

    c.assign(n + 1, 0);
    p.assign(n + 1, 0);
    g.assign(n + 1, vector<int>());

    for (int i = 1; i <= n; i++) {
        cin >> c[i];
    }

    for (int i = 1; i <= n; i++) {
        cin >> p[i];
    }

    for (int i = 1; i <= n - 1; i++) {
        int u, v;
        cin >> u >> v;

        g[u].push_back(v);
        g[v].push_back(u);
    }

    build_hld();

    seg.init(n);

    for (int i = 1; i <= n; i++) {
        int u = node_at_pos[i];
        seg.set_leaf(i, make_node_matrix(u));
    }

    seg.build();

    while (q--) {
        int u, v;
        cin >> u >> v;

        Mat res = query_path(u, v);

        long long ans = NEG;

        for (int end_state = 0; end_state < 5; end_state++) {
            ans = max(ans, res.a[2][end_state]);
        }

        cout << ans << '\n';
    }

    return 0;
}
# | 결과 | 실행 시간 | 메모리 | 채점기 출력
결과를 불러오는 중입니다…
# | 결과 | 실행 시간 | 메모리 | 채점기 출력
결과를 불러오는 중입니다…
# | 결과 | 실행 시간 | 메모리 | 채점기 출력
결과를 불러오는 중입니다…
# | 결과 | 실행 시간 | 메모리 | 채점기 출력
결과를 불러오는 중입니다…
# | 결과 | 실행 시간 | 메모리 | 채점기 출력
결과를 불러오는 중입니다…