Submission #941526

# | Time | Username | Problem | Language | Result | Execution time | Memory
941526 | — | peterandvoi | Cat Exercise (JOI23_ho_t4) | C++17
100 / 100
485 ms | 78676 KiB
#include <bits/stdc++.h>

using namespace std;

#ifdef ngu
#include "debug.h"
#else
#define debug(...) 42
#endif

// Capacity of every vertex-indexed array (vertices are numbered from 1).
constexpr int N = 200000 + 5;
// Offset used to hide vertices inside the segment tree. Presumably chosen larger
// than every p[i] so that a masked leaf becomes non-positive and is rejected by
// segment_tree::get (which only reports strictly positive maxima).
constexpr int INF = 1000000;

int n;             // number of vertices read from input
int p[N];          // per-vertex value read from input
int h[N];          // depth of each vertex below root 1 (filled by pre_dfs)
int tin[N];        // DFS entry time of each vertex (filled by pre_dfs)
int tout[N];       // largest entry time inside the vertex's subtree (filled by pre_dfs)
int pos[N];        // inverse of tin: DFS time -> vertex
vector<int> g[N];  // undirected adjacency lists

namespace lca {
    const int LOG = 19;

    int tin[N], tout[N];
    int lg[N << 1];
    int spt[LOG][N << 1];

    int order;

    void dfs(int u, int p) {
        spt[0][tin[u] = ++order] = u;
        for (int v : g[u]) {
            if (v != p) {
                dfs(v, u);
                spt[0][++order] = u;
            }
        }
        tout[u] = order;
    }

    int min_by_time(int u, int v) {
        return tin[u] < tin[v] ? u : v;
    }

    void build() {
        dfs(1, 1);
        for (int i = 2; i <= order; ++i) {
            lg[i] = lg[i / 2] + 1;
        }
        for (int j = 1; j <= lg[order]; ++j) {
            for (int i = 1; i + (1 << j) - 1 <= order; ++i) {
                spt[j][i] = min_by_time(spt[j - 1][i], spt[j - 1][i + (1 << (j - 1))]);
            }
        }
    }

    int get(int u, int v) {
        int l = min(tin[u], tin[v]);
        int r = max(tout[u], tout[v]);
        int k = lg[r - l + 1];
        return min_by_time(spt[k][l], spt[k][r - (1 << k) + 1]);
    }
}

int order;  // running DFS clock shared by pre_dfs calls

// Pre-order DFS from u: assigns tin/tout (subtree of u occupies [tin[u], tout[u]]),
// records pos[time] = vertex, and sets depth h[] for every newly reached vertex.
// A vertex counts as visited once its tin is non-zero.
void pre_dfs(int u) {
    ++order;
    tin[u] = order;
    pos[order] = u;
    for (int w : g[u]) {
        if (tin[w] != 0) {
            continue;
        }
        h[w] = h[u] + 1;
        pre_dfs(w);
    }
    tout[u] = order;
}

// Edge count of the tree path u..v: each endpoint's depth above their LCA, summed.
int dis(int u, int v) {
    const int anc = lca::get(u, v);
    return (h[u] - h[anc]) + (h[v] - h[anc]);
}

// Lazy segment tree over DFS positions 1..n supporting range add and a global
// "which vertex holds the maximum" query. Leaf i starts at p[pos[i]].
struct segment_tree {
    int n;
    vector<long long> s, lz;  // s: max over the node's range; lz: add still owed to its children
    vector<int> res;          // a vertex whose value equals s[id] (ties prefer the right child)

    segment_tree() {}

    // 4 << floor(log2(n)) nodes cover every index reachable by the recursive,
    // 1-rooted layout. n is clamped to 1 only for sizing because __lg(0) is undefined.
    segment_tree(int n): n(n) {
        const int cap = 4 << __lg(max(n, 1));
        s.resize(cap);
        lz.resize(cap);
        res.resize(cap);
    }

    // Recomputes node id from its two children.
    void pull(int id) {
        s[id] = max(s[id << 1], s[id << 1 | 1]);
        res[id] = s[id << 1] > s[id << 1 | 1] ? res[id << 1] : res[id << 1 | 1];
    }

    void build(int id, int l, int r) {
        if (l == r) {
            res[id] = pos[l];
            s[id] = p[pos[l]];
            return;
        }
        int mid = (l + r) >> 1;
        build(id << 1, l, mid);
        build(id << 1 | 1, mid + 1, r);
        pull(id);
    }

    // Initialises the tree from the global pos[] / p[] arrays.
    void build() {
        build(1, 1, n);
    }

    // Adds x to an entire node. x is long long because push() forwards the
    // accumulated lazy value, which can exceed int range once many +/-INF masks
    // pile up on one node; an int parameter would silently truncate it.
    void modify(int id, long long x) {
        s[id] += x;
        lz[id] += x;
    }

    // Hands the pending add of node id down to its children.
    void push(int id) {
        if (lz[id]) {
            modify(id << 1, lz[id]);
            modify(id << 1 | 1, lz[id]);
            lz[id] = 0;
        }
    }

    void upd(int id, int l, int r, int u, int v, long long x) {
        if (u <= l && r <= v) {
            modify(id, x);
            return;
        }
        int mid = (l + r) >> 1;
        push(id);
        if (u <= mid) {
            upd(id << 1, l, mid, u, v, x);
        }
        if (mid < v) {
            upd(id << 1 | 1, mid + 1, r, u, v, x);
        }
        pull(id);
    }

    // Adds x to every position in [u, v] (1-based, inclusive).
    void upd(int u, int v, long long x) {
        upd(1, 1, n, u, v, x);
    }

    // Vertex holding the overall maximum, or -1 when that maximum is not positive.
    int get() {
        return s[1] > 0 ? res[1] : -1;
    }
} st;

long long solve(int u) {
    long long res = 0;
    st.upd(tin[u], tout[u], -INF);
    int x = st.get();
    if (x != -1) {
        res = max(res, solve(x) + dis(u, x));
    }
    st.upd(tin[u], tout[u], INF);
    for (int v : g[u]) {
        if (tin[v] > tin[u]) {
            if (1 <= tin[v] - 1) {
                st.upd(1, tin[v] - 1, -INF);
            }
            if (tout[v] + 1 <= n) {
                st.upd(tout[v] + 1, n, -INF);
            }
            int x = st.get();
            if (x != -1) {
                res = max(res, solve(x) + dis(u, x));
            }
            if (1 <= tin[v] - 1) {
                st.upd(1, tin[v] - 1, INF);
            }
            if (tout[v] + 1 <= n) {
                st.upd(tout[v] + 1, n, INF);
            }
        }
    }
    return res;
}

signed main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    #ifdef ngu
    freopen("test.inp", "r", stdin);
    freopen("test.out", "w", stdout);
    #endif
    // Input: n, then p[1..n], then n-1 undirected edges.
    cin >> n;
    st = segment_tree(n);
    for (int i = 1; i <= n; ++i) {
        cin >> p[i];
    }
    for (int i = 1; i < n; ++i) {
        int u, v;
        cin >> u >> v;
        g[u].emplace_back(v);
        g[v].emplace_back(u);
    }
    lca::build();  // Euler tour + sparse table used by dis()
    pre_dfs(1);    // tin / tout / pos / h, rooted at vertex 1
    st.build();
    // Start from the overall maximum. st.get() reports -1 when no value is
    // positive; passing that to solve() would index the arrays at -1, so guard it.
    const int start = st.get();
    cout << (start == -1 ? 0LL : solve(start)) << '\n';
}

Compilation message (stderr)

Main.cpp: In member function 'void segment_tree::build(int, int, int)':
Main.cpp:104:21: warning: suggest parentheses around '+' inside '>>' [-Wparentheses]
  104 |         int mid = l + r >> 1;
      |                   ~~^~~
Main.cpp: In member function 'void segment_tree::upd(int, int, int, int, int, int)':
Main.cpp:132:21: warning: suggest parentheses around '+' inside '>>' [-Wparentheses]
  132 |         int mid = l + r >> 1;
      |                   ~~^~~
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...