답안 #903704

# 제출 시각 아이디 문제 언어 결과 실행 시간 메모리
903704 2024-01-11T10:48:25 Z sleepntsheep Tourism (JOI23_tourism) C++17
0 / 100
3 ms 22876 KB
#include <iostream>
#include <cassert>
#include <cstring>
#include <vector>
#include <algorithm>
#include <deque>
#include <set>
#include <utility>
#include <array>

using namespace std;
// Shorthand for passing a whole container to an algorithm.
#define ALL(x) x.begin(), x.end()
// Fast iostream setup: untie cin from cout and turn off C-stdio synchronisation.
// The expansion already ends in ';'.
#define ShinLena cin.tie(nullptr)->sync_with_stdio(false);
// Capacity of every per-vertex / per-position / per-query array.
#define N 200005
// Block size for the left endpoint in the Mo's-algorithm query ordering.
const int B = 500;

// ans_[i]: value of `ans` recorded for query i (input order); printed as ans_[i] + 1.
// n, m, q: vertex count, length of sequence c, query count.
// c[k]: k-th vertex of the sequence (0-based ids).
// tin[v]: DFS preorder entry time; timer: next entry time to hand out.
// D[v]: depth of v below root 0.
// P[j][v]: 2^j-th ancestor of v; the root's parent entry stays 0, i.e. the root itself.
int ans_[N], n, m, q, c[N], tin[N], timer, D[N], P[18][N];
// Adjacency lists of the tree (0-based vertices).
basic_string<int> g[N];

// One range query over positions [l, r] of c (0-based, inclusive).
// i is the query's position in the input, used to restore output order.
struct qry
{
    int l, r, i;

    // Mo's ordering: ascending by block of l; within one block, descending by r.
    // NOTE: main() sorts b with its own lambda comparator, so that sort does not use this operator.
    bool operator<(const qry &o) const
    {
        const int mine = l / B;
        const int theirs = o.l / B;
        if (mine == theirs)
            return o.r < r;
        return l < o.l;
    }
} b[N];

// Roots the tree at u and fills, for every vertex reachable from u:
//   tin[v]  - preorder entry time (from the global counter `timer`),
//   D[v]    - depth relative to u,
//   P[0][v] - immediate parent.
// Neighbours are visited in g[] order, exactly as a recursive DFS would.
// It uses an explicit stack instead of recursion because on a path-shaped
// tree the depth reaches n - 1, which can overflow the call stack.
void dfs(int u)
{
    tin[u] = timer++;
    // Frame = (vertex, index of the next neighbour of that vertex to examine).
    std::vector<std::pair<int, std::size_t>> st;
    st.emplace_back(u, 0);
    while (!st.empty())
    {
        const int x = st.back().first;
        std::size_t &next_idx = st.back().second;
        if (next_idx == g[x].size())
        {
            st.pop_back();            // all neighbours done: "return" from x
            continue;
        }
        const int v = g[x][next_idx++];
        if (v == P[0][x]) continue;   // do not walk back up to the parent
        D[v] = D[x] + 1;
        P[0][v] = x;
        tin[v] = timer++;             // preorder: stamped on entry
        // next_idx may dangle after this push; it is not used again this iteration.
        st.emplace_back(v, 0);
    }
}

// Lowest common ancestor of u and v by binary lifting.
// Requires D[] and all 18 levels of P[][] to be populated.
int lca(int u, int v)
{
    if (D[v] > D[u]) swap(u, v);   // let u be the deeper vertex
    // Raise u by the depth gap, one set bit at a time (highest bit first).
    for (int j = 17, gap = D[u] - D[v]; j >= 0; --j)
        if ((gap >> j) & 1)
            u = P[j][u];
    if (u == v) return u;
    // Raise both while their ancestors still differ; they end as children of the LCA.
    for (int j = 17; j >= 0; --j)
    {
        if (P[j][u] == P[j][v]) continue;
        u = P[j][u];
        v = P[j][v];
    }
    return P[0][u];
}

// Edge count of the tree path u -> v, given lca_ = lca(u, v).
int dist(int u, int v, int lca_)
{
    const int up_from_u = D[u] - D[lca_];
    const int up_from_v = D[v] - D[lca_];
    return up_from_u + up_from_v;
}

// Edge count of the tree path u -> v (computes the LCA itself).
int dist(int u, int v)
{
    const int meet = lca(u, v);
    return dist(u, v, meet);
}




const auto cmptin = [](int u, int v) { return tin[u] < tin[v]; };

set<int, decltype(cmptin)> vt(cmptin);
long long ans = 0;
int freq[N];

void add(int u)
{
    if (!freq[u]++)
    {
        auto it = vt.upper_bound(u);
        if (it != begin(vt)) ans -= dist(*prev(it), *it);

        vt.insert(u);
        it = vt.find(u);

        if (it != begin(vt)) ans += dist(*prev(it), u);
        if (next(it) != end(vt)) ans += dist(u, *next(it));
    }
}

void del(int u)
{
    if (!--freq[u])
    {
        auto it = vt.find(u);

        if (it != begin(vt)) ans -= dist(*prev(it), u);
        if (next(it) != end(vt)) ans -= dist(u, *next(it));
        it = vt.erase(it);
        if (it != end(vt) and it != begin(vt)) ans += dist(*prev(it), *it);
    }
}

int main()
{
    ShinLena;
    cin >> n >> m >> q;
    for (int u, v, i = 1; i < n; ++i) cin >> u >> v, g[--u].push_back(--v), g[v].push_back(u);
    for (int i = 0; i < m; ++i) cin >> c[i], --c[i];
    for (int i = 0; i < q; ++i) cin >> b[i].l >> b[i].r, --b[i].l, --b[i].r, b[i].i = i;
    dfs(0);
    for (int j = 1; j < 18; ++j) for (int i = 1; i <= n; ++i) P[j][i] = P[j-1][P[j-1][i]];

    sort(b, b+q, [&](const qry &a, const qry &b) {
        if (a.l / B != b.l / B) return a.l < b.l;
        return ((a.l / B) & 1) ? a.r > b.r : a.r < b.r;
    });

    int ll = b[0].l, rr = ll - 1;
    for (int i = 0; i < q; ++i)
    {
        auto [l, r, j] = b[i];
        while (ll > l) add(c[--ll]);
        while (rr < r) add(c[++rr]);
        while (ll < l) del(c[ll++]);
        while (rr > r) del(c[rr--]);
        ans_[j] = ans;
    }

    for (int i = 0; i < q; ++i) cout << ans_[i] + 1 << '\n';

    return 0;
}


# 결과 실행 시간 메모리 Grader output
1 Incorrect 3 ms 22872 KB Output isn't correct
2 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Incorrect 3 ms 22872 KB Output isn't correct
2 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Incorrect 3 ms 22872 KB Output isn't correct
2 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Incorrect 3 ms 22876 KB Output isn't correct
2 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Incorrect 3 ms 22872 KB Output isn't correct
2 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Incorrect 3 ms 22872 KB Output isn't correct
2 Halted 0 ms 0 KB -