Submission #868745

| # | Time | Username | Problem | Language | Result | Execution time | Memory |
| 868745 | | WLZ | Cat Exercise (JOI23_ho_t4) | C++17 | 100 / 100 | 195 ms | 62040 KiB |
#include <bits/stdc++.h>
using namespace std;

// https://codeforces.com/blog/entry/74847
// Lowest common ancestor using the "one jump pointer per node" scheme from the
// blog above. Every node stores its parent p, its depth d and a single jump
// pointer. kth_ancestor / query follow the jump pointer whenever it does not
// overshoot the target depth, and the parent pointer otherwise.
class lowest_common_ancestor {
    private:
    std::vector<int> p, jump, d;

    public:
    /// Builds the structure for the tree `_g` rooted at `root`.
    /// `_g` holds undirected adjacency lists (each edge listed from both ends).
    /// Nodes unreachable from `root` keep p = jump = root, d = 0 and must not
    /// be queried.
    lowest_common_ancestor(const std::vector< std::vector<int> > &_g, int root) {
        const int n = static_cast<int>(_g.size());
        // jump[root] must be root itself so that jump chains never leave the
        // tree (a default 0 would make node 0 act as a phantom ancestor).
        p.assign(n, root); jump.assign(n, root); d.assign(n, 0);
        if (n == 0) return;

        // Iterative pre-order DFS, so path-shaped trees cannot overflow the
        // call stack. A node is popped only after its parent was popped, so
        // the jump pointers of all its ancestors are already final here.
        std::vector<int> pending{root};
        while (!pending.empty()) {
            const int u = pending.back();
            pending.pop_back();

            const int a = p[u];
            // Skip two equal-length jumps at once; otherwise jump to the parent
            // (jump[u] was preset to the parent when u was discovered).
            if (d[a] - d[jump[a]] == d[jump[a]] - d[jump[jump[a]]]) jump[u] = jump[jump[a]];

            for (const int v : _g[u]) {
                if (v == p[u]) continue;
                p[v] = jump[v] = u;
                d[v] = d[u] + 1;
                pending.push_back(v);
            }
        }
    }

    /// Ancestor of u that is k levels higher. Requires 0 <= k <= depth(u).
    int kth_ancestor(int u, int k) const {
        int v = u;
        while (d[u] - d[v] < k) {
            if (d[u] - d[jump[v]] < k) v = jump[v];
            else v = p[v];
        }
        return v;
    }

    /// Lowest common ancestor of u and v.
    int query(int u, int v) const {
        if (d[u] < d[v]) return query(v, u);
        u = kth_ancestor(u, d[u] - d[v]);
        // u and v are now at equal depth, so their jump pointers target equal
        // depths too; jump together while that keeps them apart.
        while (u != v) {
            if (jump[u] == jump[v]) u = p[u], v = p[v];
            else u = jump[u], v = jump[v];
        }
        return u;
    }

    /// Number of edges between u and the root.
    int depth(int u) const { return d[u]; }

    /// Number of edges on the tree path between u and v.
    int dist(int u, int v) const { return d[u] + d[v] - 2 * d[query(u, v)]; }
};

// Disjoint-set union (path compression + union by size) that also remembers,
// for every component, the member whose h value is greatest.
class union_find {
    private:
    int n;  // element count; declared before p because p(n, -1) reads it
    // p[x] < 0 marks x as a root, and -p[x] is its component size; otherwise
    // p[x] is x's parent. largest[r] is meaningful only for roots r.
    std::vector<int> p, h, largest;

    public:
    union_find() : n(0) {}
    /// One singleton component per index of `_h`; `_h[i]` is the key used by tallest().
    explicit union_find(const std::vector<int> &_h) : n(static_cast<int>(_h.size())), p(n, -1), h(_h), largest(n) {
        std::iota(largest.begin(), largest.end(), 0);
    }

    /// Representative of a's component (compresses the path on the way).
    int root(int a) { return p[a] < 0 ? a : (p[a] = root(p[a])); }

    bool connected(int a, int b) { return root(a) == root(b); }

    /// Merges the components of a and b; returns the surviving root.
    int connect(int a, int b) {
        a = root(a); b = root(b);
        if (a == b) return a;
        // Pick the tallest member before the roots may be swapped, so that on
        // equal h the member from b's component wins.
        const int top = h[largest[a]] > h[largest[b]] ? largest[a] : largest[b];
        // Union by size: sizes are stored negated, so p[a] > p[b] means a's
        // component is the smaller one and should hang under b.
        if (p[a] > p[b]) std::swap(a, b);
        p[a] += p[b]; p[b] = a;
        largest[a] = top;
        return a;
    }

    /// Total number of elements.
    int size() const { return n; }

    /// Size of a's component.
    int size(int a) { return -p[root(a)]; }

    /// Member of a's component with the greatest h value.
    int tallest(int a) { return largest[root(a)]; }
};

int main() {
    ios::sync_with_stdio(false); cin.tie(0);

    int n; cin >> n;
    vector<int> h(n + 1);
    for (int i = 1; i <= n; i++) cin >> h[i];

    vector g(n + 1, vector<int>());
    for (int i = 0; i < n - 1; i++) {
        int a, b; cin >> a >> b;
        g[a].push_back(b); g[b].push_back(a);
    }

    vector<int> ord(n);
    iota(ord.begin(), ord.end(), 1);
    sort(ord.begin(), ord.end(), [&](int a, int b) { return h[a] < h[b]; }); 

    vector<bool> used(n + 1, false);
    vector g2(n + 1, vector<int>());

    union_find uf(h);
    for (const int &u : ord) {
        used[u] = true;
        for (const int &v : g[u]) {
            if (used[v]) {
                g2[u].push_back(uf.tallest(v));
                uf.connect(u, v);
            }
        }
    }

    lowest_common_ancestor lca(g, ord[n - 1]);

    function<long long(int)> solve = [&](int u) {
        long long ans = 0;
        for (const int &v : g2[u]) ans = max(ans, lca.dist(u, v) + solve(v));
        return ans;
    };

    cout << solve(ord[n - 1]) << '\n';

    return 0;
}
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...