제출 #781971

# | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리
781971 | TeaTime | Cat Exercise (JOI23_ho_t4) | C++17
100 / 100
291 ms | 67576 KiB
#include <iostream>
#include <vector>
#include <algorithm>
#include <set>
#include <cmath>

using namespace std;

// Fast iostream setup: unties cin/cout and disables synchronisation with C stdio.
// Expands to three statements; used once at the top of main.
#define fastInp cin.tie(0); cout.tie(0); ios_base::sync_with_stdio(0);

typedef long long ll;
typedef long double ld;  // NOTE(review): unused in this file.

// SZ: capacity of every per-vertex array (presumably max n plus slack — TODO confirm
// against the problem constraints). LG: number of binary-lifting levels, so up[v][i]
// for i in [0, LG) stores the 2^i-th ancestor; 2^LG - 1 >= SZ covers any depth.
const ll SZ = 200'100, LG = 18;

ll n;                     // number of vertices (read in main)
vector<ll> vec;           // vec[v]: height/value of vertex v, compared in main's DP
vector<vector<ll>> gr;    // adjacency lists of the (undirected) tree

// up[v][i]: 2^i-th ancestor of v (the root is its own ancestor, via zero-init).
// tin[v] / tout[v]: entry / exit timestamps assigned by dfs from the counter t.
ll up[SZ][LG], tin[SZ], tout[SZ], t = 0;
// seg: sum segment tree over Euler positions [0, 2*SZ) (8*SZ nodes = 4x the range).
// val[v]: 1 once toggle(v) has marked v. dp[v]: depth of v (edges from root 0).
ll seg[SZ * 8], val[SZ], dp[SZ];

// Point assignment on the sum segment tree: sets leaf `pos` to `val`, then
// recomputes the sums on the way back up. Node v covers the half-open range
// [l, r); its children are 2v+1 (left half) and 2v+2 (right half).
// (The parameter `val` shadows the global val[] array inside this function.)
void upd(int v, int l, int r, int pos, int val) {
    if (r - l == 1) {
        seg[v] = val;
        return;
    }
    const int mid = (l + r) / 2;
    const int left = 2 * v + 1;
    const int right = 2 * v + 2;
    if (pos >= mid) upd(right, mid, r, pos, val);
    else upd(left, l, mid, pos, val);
    seg[v] = seg[left] + seg[right];
}

// Range-sum query: returns the sum of seg leaves in [askl, askr) restricted to
// node v's half-open range [l, r). Children of v are 2v+1 and 2v+2.
ll ask(int v, int l, int r, int askl, int askr) {
    const bool disjoint = askr <= l || r <= askl;
    if (disjoint) return 0;

    const bool inside = askl <= l && r <= askr;
    if (inside) return seg[v];

    const int mid = (l + r) / 2;
    const ll left_part = ask(2 * v + 1, l, mid, askl, askr);
    const ll right_part = ask(2 * v + 2, mid, r, askl, askr);
    return left_part + right_part;
}

void dfs(int v, int par = -1) {
    for (int i = 1; i < LG; i++) {
        up[v][i] = up[up[v][i - 1]][i - 1];
    }

    t++;
    tin[v] = t;

    for (auto to : gr[v]) {
        if (to != par) {
            dp[to] = dp[v] + 1;
            up[to][0] = v;
            dfs(to, v);
        }
    }

    t++;
    tout[v] = t;
}

// True when v is an ancestor of u, or u itself: u's [tin, tout] interval nests
// inside v's. Requires tin/tout to have been filled by dfs.
bool is_par(int v, int u) {
    const bool enters_after = tin[u] >= tin[v];
    const bool leaves_before = tout[v] >= tout[u];
    return enters_after && leaves_before;
}

// Lowest common ancestor of v and u via binary lifting over up[][].
// If one is an ancestor of the other, it is returned directly. Otherwise u is
// lifted as far as possible while it stays a non-ancestor of v, and the parent
// of the final position is returned.
ll lca(int v, int u) {
    if (is_par(v, u)) return v;
    if (is_par(u, v)) return u;

    int cur = u;
    for (int k = LG; k-- > 0;) {
        const int jump = up[cur][k];
        if (!is_par(jump, v)) cur = jump;
    }
    return up[cur][0];
}

// Number of edges on the tree path between v and u, from the depths in dp[]
// and their lowest common ancestor.
ll dist(int v, int u) {
    const ll meet = lca(u, v);
    return (dp[u] - dp[meet]) + (dp[v] - dp[meet]);
}

// Number of marked vertices (see toggle) on the tree path between v and u.
// The prefix sum of seg up to tin[x] counts the marked vertices on the path
// from the root to x. Combining the prefixes of v, u and their LCA gives the
// count on the v–u path; the LCA's own mark is subtracted twice, so it is
// added back once via val[].
ll sum(int v, int u) {
    const ll meet = lca(u, v);
    auto marked_to_root = [](ll x) { return ask(0, 0, SZ * 2, 0, tin[x] + 1); };
    return marked_to_root(u) + marked_to_root(v) - 2 * marked_to_root(meet) + val[meet];
}

// Marks vertex v. Despite the name, this sets the mark and never clears it;
// calling it twice leaves the same state because upd assigns leaves.
// Writes +1 at v's entry position and -1 at its exit position, so prefix sums
// over Euler positions count the marked ancestors of a vertex.
void toggle(int v) {
    upd(0, 0, 2 * SZ, tin[v], 1);
    upd(0, 0, 2 * SZ, tout[v], -1);
    val[v] = 1;
}

// dsu[v]: parent pointer of v in the disjoint-set forest (dsu[v] == v for a root).
// opt[v]: best total distance computed in main for the component topped by v.
ll dsu[SZ], opt[SZ];

// Returns the representative of v's DSU set. It also compresses the path, so
// afterwards every vertex visited points directly at the representative.
// Two iterative passes replace the original recursion: uni() does no
// union-by-rank, so parent chains can approach n in length, and a recursive
// find of that depth risks overflowing the call stack.
ll find(int v) {
    ll root = v;
    while (dsu[root] != root) root = dsu[root];

    ll cur = v;
    while (cur != root) {
        const ll parent = dsu[cur];
        dsu[cur] = root;
        cur = parent;
    }
    return root;
}

// Merges the DSU sets containing v and u. The representative of v's set
// survives: u's representative is hung directly beneath it. There is no
// union-by-rank, so the caller decides which root remains.
void uni(int v, int u) {
    const int keep = find(v);
    const int absorb = find(u);
    if (keep == absorb) return;
    dsu[absorb] = keep;
}

int main() {
    fastInp;

    cin >> n;

    vec.resize(n);
    for (auto &c : vec) cin >> c;
    vector<pair<ll, ll>> srt;
    for (int i = 0; i < n; i++) {
        srt.push_back({vec[i], i});
        dsu[i] = i;
    }

    sort(srt.rbegin(), srt.rend());

    gr.resize(n);
    for (int i = 0; i < n - 1; i++) {
        ll u, v;
        cin >> u >> v;
        u--; v--;
        gr[u].push_back(v);
        gr[v].push_back(u);
    }

    dfs(0);

    ll prev = srt[0].second, ans = 0;
    for (int ii = n - 1; ii >= 0; ii--) {
        int i = srt[ii].second;
        for (auto to : gr[i]) {
            if (vec[to] < vec[i]) {
                opt[i] = max(opt[i], opt[find(to)] + dist(find(to), i));
                
                uni(i, to);
            }
        }
    }

    ans = opt[srt[0].second];
    //for (auto c : srt) cout << c.first << " " << c.second << "\n";

    cout << ans;
    return 0;
    toggle(prev);

    srt.erase(srt.begin());
    for (auto c : srt) {
        int ind = c.second;

        cerr << prev << " " << ind << " " << sum(prev, ind) << " " << dist(prev, ind) << "\n";
        if (sum(prev, ind) == 1) {
            ans += dist(prev, ind);
            toggle(ind);
            prev = ind;
        }
    }

    cout << ans;

    return 0;
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...