Submission #1303866

# | Time | Username | Problem | Language | Result | Execution time | Memory
1303866 | | nguyn | Tourism (JOI23_tourism) | C++20 | 100 / 100 | 526 ms | 91712 KiB
#include <bits/stdc++.h>

using namespace std;
#ifdef nguyn_debug
#include "debug.h"
#else
#define debug(...)
#endif // nguyn_debug

#define ll long long
#define F first
#define S second
#define pb push_back
#define pii pair<int, int>
#define sz(x) (int)(x).size()
#define all(x) (x).begin(), (x).end()

const int N = 1e5 + 5;
const int mod = 1e9 + 7;
const ll inf = 1e18;

int tin[N], timedfs, rmq[N][20];
pii rmq_euler[N * 2][20]; 
int h[N], sz[N]; 

int n, m, q; 
int c[N]; 
vector<pii> ev[N];
vector<int> pos[N]; 
vector<int> g[N]; 
vector<pii> add_ev[N]; 
int res[N]; 
set<int> st;

void predfs(int u, int p) {
    tin[u] = ++timedfs; 
    sz[u] = 1;
    rmq_euler[timedfs][0] = {h[u], u};
    for (int v : g[u]) {
        if (v == p) continue;
        h[v] = h[u] + 1;
        predfs(v, u);
        sz[u] += sz[v]; 
        rmq_euler[++timedfs][0] = {h[u], u}; 
    }
}

int get_lca(int u, int v) {
    if (tin[u] > tin[v]) swap(u, v);
    int lg = 31 - __builtin_clz(tin[v] - tin[u] + 1);
    return min(rmq_euler[tin[u]][lg], rmq_euler[tin[v] - (1 << lg) + 1][lg]).S; 
}

int get_range_lca(int u, int v) {
    int lg = 31 - __builtin_clz(v - u + 1); 
    return get_lca(rmq[u][lg], rmq[v - (1 << lg) + 1][lg]); 
}

void add_to_st(int i, int delta) {
    auto it = st.insert(i).F; 
    int pre = *prev(it) + 1;
    int nxt = *next(it) - 1; 
    add_ev[pre].pb({nxt, delta}); 
    add_ev[pre].pb({i - 1, - delta});
    add_ev[i + 1].pb({nxt, - delta}); 
}

void add_node(int u, int p, int delta) {
    for (int i : pos[u]) {
        add_to_st(i, delta); 
    }
    for (int v : g[u]) {
        if (v == p) continue;
        add_node(v, u, delta); 
    }
}

void erase_node(int u, int p) {
    for (int i : pos[u]) st.erase(st.find(i)); 
    for (int v : g[u]) {
        if (v == p) continue;
        erase_node(v, u); 
    }
}

void dfs(int u, int p, int dep, bool keep) {
    int big = 0; 
    // debug(u, dep); 
    for (int v : g[u]) {
        if (v == p) continue;
        if (sz[v] > sz[big]) big = v; 
    }
    for (int v : g[u]) {
        if (v == p || v == big) continue;
        dfs(v, u, 1, 0); 
    }
    if (big) dfs(big, u, dep + 1, 1); 
    for (int i : pos[u]) {
        add_to_st(i, dep); 
    }
    for (int v : g[u]) {
        if (v == p || v == big) continue;
        add_node(v, u, dep); 
    }
    if (!keep) erase_node(u, p); 
}

struct BIT {
    int n;
    vector<int> bit;

    BIT() {}
    BIT (int n) : n(n) {
        bit.assign(n + 3, 0); 
    }

    void update(int id, int val) {
        for (int i = id; i <= n; i += (i & -i)) bit[i] += val; 
    }

    int get(int id) {
        int res = 0;
        for (int i = id; i > 0; i -= (i & -i)) {
            res += bit[i]; 
        }
        return res; 
    }
} bit;

signed main() {
    ios_base::sync_with_stdio(0);
    cin.tie(0); cout.tie(0);
    cin >> n >> m >> q;
    for (int i = 1; i < n; i++) {
        int u, v;
        cin >> u >> v;
        g[u].pb(v);
        g[v].pb(u); 
    }
    predfs(1, 0); 
    for (int i = 1; i < 20; i++) {
        for (int j = 1; j + (1 << i) - 1 <= timedfs; j++) {
            rmq_euler[j][i] = min(rmq_euler[j][i - 1], rmq_euler[j + (1 << (i - 1))][i - 1]); 
        }
    }
    for (int i = 1; i <= m; i++) {
        cin >> c[i]; pos[c[i]].pb(i); 
        rmq[i][0] = c[i]; 
    }
    for (int i = 1; i < 20; i++) {
        for (int j = 1; j + (1 << i) - 1 <= m; j++) {
            rmq[j][i] = get_lca(rmq[j][i - 1], rmq[j + (1 << (i - 1))][i - 1]); 
        }
    }
    for (int i = 1; i <= q; i++) {
        int l, r;
        cin >> l >> r; 
        ev[l].pb({r, i}); 
    }
    st.insert(0);
    st.insert(m + 1); 
    dfs(1, 0, 1, 1); 
    bit = BIT(m); 
    for (int i = 1; i <= m; i++) {
        for (auto it : add_ev[i]) {
            int r = it.F;
            int val = it.S;
            // debug(i, r, val); 
            bit.update(i, val);
            bit.update(r + 1, - val); 
        }
        for (auto it : ev[i]) {
            int r = it.F;
            int id = it.S;
            // debug(bit.get(r), get_range_lca(i, r), i, r); 
            res[id] = bit.get(r) - h[get_range_lca(i, r)]; 
        }
    }
    for (int i = 1; i <= q; i++) {
        cout << res[i] << '\n'; 
    }
}
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...