Submission #869444

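| # | Time | Username | Problem | Language | Result | Execution time | Memory |
|---|------|----------|---------|----------|--------|----------------|--------|
| 869444 | (not captured) | makrav | Cat Exercise (JOI23_ho_t4) | C++14 | 100 / 100 | 772 ms | 230064 KiB |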
#define _USE_MATH_DEFINES
#include <bits/stdc++.h>

using namespace std;

// ---- Contest shorthand ---------------------------------------------------
// Type aliases. These typedefs are read before `#define int ll` further down,
// so `vei` / `vevei` hold plain `int`, not `long long`.
typedef long long ll;
typedef long double ld;
typedef vector<int> vei;
typedef vector<vei> vevei;

// Container / output helpers.
// NOTE(review): `sz` does not parenthesize its argument. Its replacement list
// is rescanned where it is used, so after `#define int ll` the cast is `(ll)`.
#define all(a) (a).begin(), (a).end()
#define sz(a) (int) a.size()
// NOTE(review): `con` has no trailing semicolon but `coe` does; the two are
// not interchangeable in statement position.
#define con cout << "NO\n"
#define coe cout << "YES\n";
// Short spellings for common std names and pair members.
#define str string
#define pb push_back
#define ff first
#define sc second
#define pii pair<int, int>
#define mxe max_element
#define mne min_element
#define stf shrink_to_fit
// f(i, l, r): half-open loop, i = l, l + 1, ..., r - 1.
#define f(i, l, r) for (int i = (l); i < (r); i++)
// From here on every `double` means long double and every `int` means
// long long. That is why the entry point must be spelled `signed main()`.
#define double ld
#define int ll

// Size of the global floor-log2 lookup table below.
const int MAXN = 200010;

// Log_2[k] == floor(log2(k)) for 1 <= k < MAXN; Log_2[0] stays 0.
// Every entry is 0 until fill() has run.
std::vector<int> Log_2(MAXN, 0);

// Populates Log_2 via floor(log2(k)) = floor(log2(k / 2)) + 1.
void fill() {
    int k = 2;
    while (k < MAXN) {
        Log_2[k] = 1 + Log_2[k >> 1];
        ++k;
    }
}

// Static range-maximum table over a[0 .. n-1].
// Build is O(n log n); each query is O(1) using two overlapping windows whose
// length is a power of two. The floor-log2 table is kept per instance, so
// construction does not depend on the global Log_2 having been filled first
// and is not limited by that table's size.
struct SparseTable {
    int n;
    std::vector<int> a;
    // mx[i][k] = max of a[i .. i + 2^k - 1]. Entries whose window would run
    // past the end of the array stay -1.
    std::vector<std::vector<int>> mx;
    // lg[k] = floor(log2(k)) for 1 <= k <= n (lg[0] = 0).
    std::vector<int> lg;

    // n_: number of elements indexed; requires a_.size() >= n_.
    // a_ is taken by value and moved in, so it is copied at most once.
    SparseTable(int n_, std::vector<int> a_) {
        n = n_;
        a = std::move(a_);
        lg.assign(n + 1, 0);
        for (int k = 2; k <= n; k++) lg[k] = lg[k / 2] + 1;
        mx.resize(n, std::vector<int>(lg[n] + 1, -1));
        build();
    }

    // Fills mx bottom-up: level st is merged from two halves at level st - 1.
    void build() {
        for (int i = n - 1; i >= 0; i--) {
            mx[i][0] = a[i];
            int st = 1;
            while ((1 << st) + i <= n) {
                mx[i][st] = std::max(mx[(1 << (st - 1)) + i][st - 1], mx[i][st - 1]);
                st++;
            }
        }
    }

    // Maximum of a[l .. r], inclusive and 0-indexed. Requires 0 <= l <= r < n.
    int req(int l, int r) {
        int k = lg[r - l + 1];
        return std::max(mx[l][k], mx[r - (1 << k) + 1][k]);
    }
};
// Random source for treap priorities, seeded from the wall clock.
std::mt19937 rnd(std::time(nullptr));

// One node of the Euler-tour treap used by cartesiantree. Every field has a
// default, so a default-constructed node is fully initialized (all zero,
// null links).
struct node {
    int val = 0;    // payload returned by kth(); main stores the tree vertex here
    int siz = 0;    // number of nodes in this treap subtree
    int prior = 0;  // random heap priority
    int ind = 0;    // index passed at construction (main: position in the tour)
    int mn = 0;     // NOTE: despite the name, the MAXIMUM hg in this subtree
    int hg = 0;     // per-node key aggregated into mn
    node* l = nullptr, * r = nullptr, * par = nullptr;

    node() = default;
    // Leaf node: size 1, aggregate equals its own key H, fresh random priority.
    node(int val_, int H, int ind_)
        : val(val_), siz(1), prior(rnd()), ind(ind_), mn(H), hg(H) {}
};

vector<pair<node*, node*>> euler;

struct cartesiantree {
    node* r = nullptr;

    int size(node* root) {
        if (root == nullptr) return 0;
        return root->siz;
    }

    int mn(node* root) {
        if (root == nullptr) return 0;
        return root->mn;
    }

    void upd(node* root) {
        if (root == nullptr) return;
        root->siz = size(root->l) + size(root->r) + 1;
        root->mn = max(root->hg, max(mn(root->l), mn(root->r)));
        if (root->l != nullptr) root->l->par = root;
        if (root->r != nullptr) root->r->par = root;
    }

    pair<node*, node*> split(node* root, int x) {
        if (root == nullptr) return { nullptr, nullptr };
        if (size(root->l) < x) {
            pair<node*, node*> p = split(root->r, x - size(root->l) - 1);
            if (p.ff != nullptr)p.ff->par = nullptr;
            if (p.sc != nullptr)p.sc->par = nullptr;
            root->r = p.ff;
            upd(root);
            return { root, p.sc };
        }
        else {
            pair<node*, node*> p = split(root->l, x);
            if (p.ff != nullptr)p.ff->par = nullptr;
            if (p.sc != nullptr)p.sc->par = nullptr;
            root->l = p.sc;
            upd(root);
            return { p.ff, root };
        }
    }

    node* merge(node* a, node* b) {
        if (a == nullptr) return b;
        if (b == nullptr) return a;
        if (a->prior < b->prior) {
            node* x = merge(a, b->l);
            b->l = x;
            upd(b);
            upd(a);
            return b;
        }
        else {
            node* x = merge(a->r, b);
            a->r = x;
            upd(a);
            upd(b);
            return a;
        }
    }

    int get_pos(node* x) {
        int ps = size(x->l);
        while (x->par != nullptr) {
            if (x->par->r == x) ps += size(x->par->l) + 1;
            x = x->par;
        }
        return ps;
    }

    node* getroot(node* x) {
        while (x->par != nullptr) x = x->par;
        return x;
    }

    int kth(node* root, int k) {
        if (root == nullptr)return -1;
        if (size(root->l) == k) {
            cout << (root->par == nullptr ? -1 : root->par->ind) << ' ';
            return root->val;
        }
        if (size(root->l) > k) {
            return kth(root->l, k);
        }
        else {
            return kth(root->r, k - size(root->l) - 1);
        }
    }

    void SPL(int x) {
        int left = get_pos(euler[x].ff), right = get_pos(euler[x].sc);
        auto p = split(getroot(euler[x].ff), left);
        auto p2 = split(p.sc, right - left + 1);
        merge(p.ff, p2.sc);
    }

    void mrg(int x, int y) {
        int left = 1 + get_pos(euler[y].ff);
        auto p = split(getroot(euler[y].ff), left);
        merge(p.ff, merge(getroot(euler[x].ff), p.sc));
    }
};

// Adjacency sets of the tree, 0-indexed. main() erases edges from these while
// splitting components.
std::vector<std::unordered_set<int>> g;
// order: Euler tour, in which each vertex is appended on entry and again on exit.
// par:   parent of each vertex in the rooting chosen by the dfs call.
std::vector<int> order, par;

// DFS from v, whose parent is p. Records par[] and appends each vertex to
// `order` on entry and again on exit, exactly as a recursive pre/post-order
// walk in g's iteration order would. It uses an explicit stack, so a
// path-shaped tree with hundreds of thousands of vertices cannot overflow the
// call stack. par must already be sized for every reachable vertex.
void dfs(int v, int p) {
    struct Frame {
        int vertex;
        int parent;
        std::unordered_set<int>::iterator next;  // next neighbour to try
    };
    std::vector<Frame> frames;
    par[v] = p;
    order.push_back(v);
    frames.push_back({v, p, g[v].begin()});
    while (!frames.empty()) {
        Frame& top = frames.back();
        if (top.next == g[top.vertex].end()) {
            order.push_back(top.vertex);  // exit event
            frames.pop_back();
            continue;
        }
        const int u = *top.next;
        ++top.next;
        const int from = top.vertex;
        if (u == top.parent) continue;
        par[u] = from;
        order.push_back(u);  // entry event
        frames.push_back({u, from, g[u].begin()});  // invalidates `top`
    }
}

// Binary-lifting lowest common ancestor for a tree given as adjacency sets,
// rooted at vertex 0. lca() is O(log n); len() is the edge distance.
struct LCA {
    int n;
    // Copy of the adjacency taken at construction, so later edits to the
    // caller's sets do not affect this structure.
    std::vector<std::unordered_set<int>> g;
    std::vector<std::vector<int>> up;  // up[v][i]: 2^i-th ancestor (the root maps to itself)
    std::vector<int> tin, tout, h;     // entry/exit timestamps and depth
    int timer = 0, l;                  // l: number of lifting levels

    LCA() = default;
    LCA(int n_, std::vector<std::unordered_set<int>>& g_) {
        n = n_;
        g = g_;
        // l = floor(log2(n)) + 1, at least 1. Computed with integers instead
        // of floating-point log2 so that n == 0 is well defined.
        l = 1;
        while ((1 << l) <= n) l++;
        up.assign(n, std::vector<int>(l, 0));
        tin.assign(n, 0);
        tout.assign(n, 0);
        h.assign(n, 0);
        if (n > 0) dfs(0, 0, 0);
    }

    // Iterative DFS from v (parent p, depth hg). It assigns tin/tout/h/up in
    // the same order as a recursive walk over g's iteration order, without
    // risking call-stack overflow on deep trees.
    void dfs(int v, int p, int hg) {
        struct Frame {
            int vertex;
            int parent;
            std::unordered_set<int>::iterator next;  // next neighbour to try
        };
        std::vector<Frame> frames;
        auto enter = [&](int x, int px, int depth) {
            h[x] = depth;
            tin[x] = timer++;
            up[x][0] = px;
            for (int i = 1; i < l; i++) {
                up[x][i] = up[up[x][i - 1]][i - 1];
            }
            frames.push_back({x, px, g[x].begin()});
        };
        enter(v, p, hg);
        while (!frames.empty()) {
            Frame& top = frames.back();
            if (top.next == g[top.vertex].end()) {
                tout[top.vertex] = timer++;
                frames.pop_back();
                continue;
            }
            const int u = *top.next;
            ++top.next;
            const int from = top.vertex;
            // `top` may be invalidated inside enter(); only copies are used.
            if (u != top.parent) enter(u, from, h[from] + 1);
        }
    }

    // True when a is an ancestor of b (or a == b).
    bool parent(int a, int b) {
        return tin[a] <= tin[b] && tout[a] >= tout[b];
    }

    int lca(int a, int b) {
        if (parent(a, b)) return a;
        if (parent(b, a)) return b;

        // Lift a to the highest ancestor that is still not an ancestor of b.
        for (int i = l - 1; i >= 0; i--) {
            if (!parent(up[a][i], b)) {
                a = up[a][i];
            }
        }
        return up[a][0];
    }

    // Number of edges on the path between a and b.
    int len(int a, int b) {
        return h[a] + h[b] - 2 * h[lca(a, b)];
    }
};

signed main() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);
    fill();

    int n; cin >> n;
    vector<int> h(n);
    f(i, 0, n) cin >> h[i];
    SparseTable sp(n, h);
    vector<int> pos(n + 1, 0);
    for (int i = 0; i < n; i++) {
        pos[h[i]] = i;
    }

    vector<pair<int, int>> e(n - 1);
    f(i, 0, n - 1) cin >> e[i].ff >> e[i].sc;

    g.resize(n);
    par.assign(n, 0);
    for (auto& u : e) {
        g[u.ff - 1].insert(u.sc - 1);
        g[u.sc - 1].insert(u.ff - 1);
    }
    LCA L(n, g);

    dfs(0, 0);
    cartesiantree ct;
    euler.resize(n);
    for (int i = 0; i < sz(order); i++) {
        node* nw = new node(order[i], h[order[i]], i);
        if (euler[order[i]].ff == nullptr) {
            euler[order[i]].ff = nw;
            ct.r = ct.merge(ct.r, euler[order[i]].ff);
        }
        else {
            euler[order[i]].sc = nw;
            ct.r = ct.merge(ct.r, euler[order[i]].sc);
        }
    }
    int ans = 0;

    vector<pair<node*, int>> roots = { {ct.r, 0} };
    int op = 0;
    while (!roots.empty()) {
        op++;

        auto cur = roots.back();
        roots.pop_back();
        int H = ct.mn(cur.ff);
        int v = pos[H];

        int SZ = ct.size(cur.ff);

        if (SZ == 2) {
            ans = max(ans, cur.sc);
        }

        vector<int> cut;
        for (auto& u : g[v]) {
            if (u != par[v]) {
                cut.pb(u);
            }
        }

        for (auto& u : cut) {
            g[v].erase(u);
            g[u].erase(v);
            ct.SPL(u);
            int v2 = pos[ct.mn(ct.getroot(euler[u].ff))];
            roots.pb({ ct.getroot(euler[u].ff), cur.sc + L.len(v, v2) });
        }
        if (!g[v].empty()) {
            ct.SPL(v);
            int u = *g[v].begin();
            g[v].erase(u);
            g[u].erase(v);
            int v2 = pos[ct.mn(ct.getroot(euler[u].ff))];
            roots.pb({ ct.getroot(euler[u].ff), cur.sc + L.len(v, v2) });
        }
    }
    cout << ans << '\n';

    return 0;
}

Compilation message (stderr)

Main.cpp: In member function 'll SparseTable::req(ll, ll)':
Main.cpp:60:14: warning: unused variable 'sw' [-Wunused-variable]
   60 |         bool sw = false;
      |              ^~
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...