제출 #937573

# | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리
937573 | — | zwezdinv | Cat Exercise (JOI23_ho_t4) | C++17 | 100 / 100 | 157 ms | 49076 KiB
#include <bits/stdc++.h>

// Disjoint-set union over vertices 0..n-1 that also records, for each set,
// the largest vertex index it contains (read via max()).
struct DSU {
        int n = 0;            // number of elements (0 for a default-constructed DSU)
        std::vector<int> p;   // parent links; p[r] == r exactly when r is a root
        std::vector<int> mx;  // mx[r]: largest vertex index in the set rooted at r

        DSU() = default;

        // Creates n singleton sets {0}, {1}, ..., {n-1}; each set's maximum is
        // its only element.
        DSU(int n) : n(n), p(n), mx(n) {
                std::iota(p.begin(), p.end(), 0);
                std::iota(mx.begin(), mx.end(), 0);
        }

        // Returns the root of u's set. Path halving: every node visited on the
        // way up is re-pointed at its grandparent.
        int get(int u) {
                while (p[u] != u) {
                        p[u] = p[p[u]];
                        u = p[u];
                }
                return u;
        }

        // Merges the sets of u and v. v's root is attached under u's root, and
        // the surviving root keeps the larger of the two set maxima.
        void unite(int u, int v) {
                const int ru = get(u);
                const int rv = get(v);
                if (ru == rv) return;  // already in the same set
                mx[ru] = std::max(mx[ru], mx[rv]);
                p[rv] = ru;
        }

        // Largest vertex index in u's set.
        int max(int u) {
                return mx[get(u)];
        }
};

// Static range-minimum segment tree over std::pair<int, int>, stored
// bottom-up: leaves occupy tr[n..2n) and internal node i holds
// min(tr[2i], tr[2i + 1]). Pairs compare lexicographically.
struct segtree {
        // Identity element for min; this is what get() returns for an empty range.
        static constexpr std::pair<int, int> kIdentity{1000000000, 1000000000};

        int n = 0;                             // number of leaves
        std::vector<std::pair<int, int>> tr;   // 2 * n nodes; tr[0] is unused

        segtree() = default;

        // n leaves, every node initialised to kIdentity.
        segtree(int n) : n(n), tr(2 * n, kIdentity) {}

        // Builds the tree with leaf i = a[i]. The elements are moved out of the
        // by-value parameter instead of being copied a second time.
        segtree(std::vector<std::pair<int, int>> a)
                : n(static_cast<int>(a.size())), tr(2 * a.size()) {
                std::move(a.begin(), a.end(), tr.begin() + n);
                for (int i = n - 1; i > 0; --i) {
                        tr[i] = std::min(tr[i << 1], tr[i << 1 | 1]);
                }
        }

        // Minimum over the inclusive index range [l, r]; kIdentity when l > r.
        // Expects 0 <= l and r < n (not checked).
        std::pair<int, int> get(int l, int r) const {
                std::pair<int, int> res = kIdentity;
                for (l += n, r += n; l <= r; l >>= 1, r >>= 1) {
                        // A left bound that is a right child, or a right bound that
                        // is a left child, is consumed here before moving up a level.
                        if (l & 1) res = std::min(res, tr[l++]);
                        if (~r & 1) res = std::min(res, tr[r--]);
                }
                return res;
        }
};

// Reads n, a label p[i] for each input vertex, and n - 1 tree edges, then
// prints dp[n - 1] computed by the DSU sweep below.
int main() {
        std::cin.tie(nullptr)->sync_with_stdio(false);

        int n = 0;
        // Nothing to compute for a missing or non-positive vertex count; the
        // traversal below would otherwise index g[0] of an empty graph.
        if (!(std::cin >> n) || n <= 0) return 0;

        // Every vertex is renamed to its 0-based label p[i], so comparing vertex
        // ids below compares p values.
        // NOTE(review): assumes p is a permutation of 1..n; confirm with the statement.
        std::vector<int> p(n);
        for (auto& x : p) {
                std::cin >> x;
                --x;
        }
        std::vector<std::vector<int>> g(n);
        for (int i = 0; i < n - 1; ++i) {
                int u = 0;
                int v = 0;
                std::cin >> u >> v;
                u = p[u - 1];
                v = p[v - 1];
                g[u].push_back(v);
                g[v].push_back(u);
        }

        // Euler tour rooted at vertex 0. Each vertex is recorded as a (depth, vertex)
        // pair on entry and again after each of its children finishes. The minimum
        // entry between the first occurrences of u and v is then their LCA.
        // An explicit stack replaces recursion, so a path-shaped tree (depth ~n)
        // cannot overflow the call stack.
        std::vector<int> depth(n, 0);
        std::vector<int> pos(n, 0);
        std::vector<std::pair<int, int>> tour;
        tour.reserve(2 * static_cast<std::size_t>(n) - 1);
        struct Frame {
                int u;
                int parent;
                std::size_t next;  // index of the next neighbour of u to examine
        };
        std::vector<Frame> frames;
        frames.reserve(n);
        pos[0] = 0;
        tour.emplace_back(depth[0], 0);
        frames.push_back({0, -1, 0});
        while (!frames.empty()) {
                Frame& top = frames.back();
                if (top.next == g[top.u].size()) {
                        frames.pop_back();
                        if (!frames.empty()) {
                                // Returned to the parent after finishing a child: record it again.
                                const int up = frames.back().u;
                                tour.emplace_back(depth[up], up);
                        }
                        continue;
                }
                const int u = top.u;
                const int to = g[u][top.next++];
                if (to == top.parent) continue;
                depth[to] = depth[u] + 1;
                pos[to] = static_cast<int>(tour.size());
                tour.emplace_back(depth[to], to);
                frames.push_back({to, u, 0});  // may reallocate; `top` is not used after this
        }
        segtree tr(std::move(tour));

        auto lca = [&](int u, int v) -> int {
                u = pos[u];
                v = pos[v];
                if (u > v) std::swap(u, v);
                return tr.get(u, v).second;
        };
        auto dist = [&](int u, int v) -> int {
                return depth[u] + depth[v] - 2 * depth[lca(u, v)];
        };

        // Visit vertices in increasing label order. For vertex i, each neighbour
        // j < i belongs to a DSU component of already-visited vertices, and `best`
        // is that component's largest label. dp[i] is the maximum over such j of
        // dp[best] + dist(i, best); afterwards j's component is merged into i's.
        std::vector<long long> dp(n, 0);
        DSU dsu(n);
        for (int i = 0; i < n; ++i) {
                for (const int j : g[i]) {
                        if (j > i) continue;
                        const int best = dsu.max(j);
                        dp[i] = std::max(dp[i], dp[best] + dist(i, best));
                        dsu.unite(i, j);
                }
        }
        std::cout << dp[n - 1] << '\n';
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...