Submission #1031503

# | Time | Username | Problem | Language | Result | Execution time | Memory
1031503 | (not shown) | mdn2002 | Tourism (JOI23_tourism) | C++17 | 100 / 100 | 847 ms | 92760 KiB
/*
Mayoeba Yabureru
*/
#include<bits/stdc++.h>
using namespace std;
// Fenwick tree (binary indexed tree) holding int values at positions 0 .. n-1
// and answering prefix sums. Cell i stores the sum of positions
// (i & (i + 1)) .. i, which gives the classic 0-based update/query walks.
struct FenwickTree {
    vector<int> bit;  // bit[i] = sum of positions (i & (i + 1)) .. i
    int n;            // number of positions

    // Builds a tree of n positions, all zero.
    FenwickTree(int n) : bit(n, 0), n(n) {}

    // Returns the sum of positions 0 .. r; returns 0 for r < 0. Expects r < n.
    int sum(int r) {
        int total = 0;
        while (r >= 0) {
            total += bit[r];
            r = (r & (r + 1)) - 1;
        }
        return total;
    }

    // Adds delta to position idx (expects idx >= 0; idx >= n is a no-op).
    void add(int idx, int delta) {
        while (idx < n) {
            bit[idx] += delta;
            idx |= idx + 1;
        }
    }
};

// Reads one test case from stdin:
//   n m q, then n-1 tree edges, then c[1..m], then q queries (l, r).
// Prints one answer per query, one per line, in input order.
//
// Structure of the computation (as implemented below):
//  * For r >= 2, "path r" is the tree path between c[r-1] and c[r].
//    Fenwick index r starts with the vertex count of path r.
//  * When path r is processed, overlap with earlier paths is subtracted at
//    the index of the earlier path (see `cover` below).
//  * A query with l < r is answered as bit.sum(r) - bit.sum(l), i.e. the sum
//    over indices l+1..r. A query with l == r is answered as 1.
//    This appears to count each vertex of the union of paths l+1..r once,
//    charged to the latest path covering it (offline "last occurrence" trick).
void solve() {
    int n, m, q;
    cin >> n >> m >> q;
    FenwickTree bit(m + 1);
    vector<vector<int>> gr(n + 1);
    // up[v][j]: 2^j-th ancestor of v; vertex 0 is the sentinel above the root.
    vector up(n + 1, vector<int> (20));
    // depth[v]: root (vertex 1) has depth 1, sentinel 0 has depth 0.
    // subSize[v]: number of vertices in the subtree of v.
    vector<int> depth(n + 1), subSize(n + 1);
    for (int i = 1; i < n; i ++) {
        int x, y;
        cin >> x >> y;
        gr[x].push_back(y);
        gr[y].push_back(x);
    }

    // NOTE(review): dfs and hld recurse once per tree level, so on a
    // path-shaped tree the recursion depth is about n. Confirm the judge's
    // stack limit is large enough.
    function<void(int, int)> dfs = [&] (int x, int p) {
        up[x][0] = p;
        depth[x] = depth[p] + 1;
        subSize[x] = 1;
        for (auto u : gr[x]) {
            if (u == p) continue;
            dfs(u, x);
            subSize[x] += subSize[u];
        }
    };
    dfs(1, 0);
    for (int j = 1; j <= 19; j ++) {
        for (int i = 1; i <= n; i ++) up[i][j] = up[up[i][j - 1]][j - 1];
    }

    // Lowest common ancestor by binary lifting (not recursive, so a plain
    // lambda is enough; no std::function needed).
    auto lca = [&] (int x, int y) {
        if (depth[x] > depth[y]) swap(x, y);
        const int dif = depth[y] - depth[x];
        for (int i = 0; i < 20; i ++) {
            if ((dif >> i) & 1) y = up[y][i];
        }
        if (x == y) return x;
        for (int i = 19; i >= 0; i --) {
            if (up[x][i] != up[y][i]) {
                x = up[x][i];
                y = up[y][i];
            }
        }
        return up[x][0];
    };

    vector<int> c(m + 1), ans(q + 1);
    for (int i = 1; i <= m; i ++) cin >> c[i];
    // Queries bucketed by right endpoint: qr[r] holds (l, query index).
    vector qr(m + 1, vector<pair<int, int>>());
    for (int i = 0; i < q; i ++) {
        int l, r;
        cin >> l >> r;
        qr[r].push_back({l, i});
    }

    // Path r (r >= 2) joins x = c[r] and y = c[r-1] through z = lca(x, y).
    // Its endpoints get tags (r, 0) at x and (r, 1) at y, and both tags are
    // removed again at z.
    vector s(n + 1, set<pair<int, int>>());
    vector add(n + 1, vector<pair<int, int>>());
    vector del(n + 1, vector<int>());
    for (int r = 2; r <= m; r ++) {
        int x = c[r], y = c[r - 1], z = lca(x, y);
        add[x].push_back({r, 0});
        add[y].push_back({r, 1});
        del[z].push_back(r);
    }

    // cover[r][side] collects (j, d) pairs. They are pushed bottom-up along
    // side 0 (from c[r]) or side 1 (from c[r-1]) of path r. (j, d) records
    // that at depth d the nearest smaller-index tag in the merged set belongs
    // to path j (0 = none). The final loop turns consecutive entries into
    // Fenwick subtractions.
    vector cover(m + 1, vector<vector<pair<int, int>>>(2));

    // Small-to-large merge of endpoint tags. Afterwards, s[x] holds the tags
    // of paths that have an endpoint in x's subtree and whose LCA is not
    // strictly below x. The heavy child's set is reused by swap, and light
    // children's sets are inserted into it.
    function<void(int, int)> hld = [&] (int x, int p) {
        // Heavy child = child with the largest subtree (first one on ties).
        int heavySize = 0, heavy = 0;
        for (auto u : gr[x]) {
            if (u == p) continue;
            if (heavySize < subSize[u]) {
                heavySize = subSize[u];
                heavy = u;
            }
        }
        for (auto u : gr[x]) {
            if (u == p || u == heavy) continue;
            hld(u, x);
        }
        if (heavy) {
            hld(heavy, x);
            swap(s[x], s[heavy]);
        }
        for (auto z : add[x]) s[x].insert(z);
        for (auto u : gr[x]) {
            if (u == p || u == heavy) continue;
            for (auto z : s[u]) s[x].insert(z);
        }

        // New tag z, next larger tag k of a different path: at depth[x],
        // path k.first is preceded by path z.first.
        auto recordSuccessor = [&] (const pair<int, int>& z) {
            auto it = s[x].upper_bound(z);
            if (it == s[x].end()) return;
            const auto k = *it;
            if (z.first == k.first) return;
            cover[k.first][k.second].push_back({z.first, depth[x]});
        };
        // New tag z, next smaller tag k of a different path: at depth[x],
        // path z.first is preceded by path k.first.
        auto recordPredecessor = [&] (const pair<int, int>& z) {
            auto it = s[x].lower_bound(z);
            if (it == s[x].begin()) return;
            const auto k = *--it;
            if (z.first == k.first) return;
            cover[z.first][z.second].push_back({k.first, depth[x]});
        };

        // Only newly merged tags (light children, then this vertex's own) are
        // queried. The push order is kept as-is because entries at equal
        // depth are consumed in order by the final loop.
        for (auto u : gr[x]) {
            if (u == p || u == heavy) continue;
            for (const auto& z : s[u]) recordSuccessor(z);
            for (const auto& z : s[u]) recordPredecessor(z);
            s[u].clear();
        }
        for (const auto& z : add[x]) recordSuccessor(z);
        for (const auto& z : add[x]) recordPredecessor(z);

        // Paths whose LCA is x end here: drop both of their tags.
        for (auto u : del[x]) {
            s[x].erase({u, 0});
            s[x].erase({u, 1});
        }

        // After removing path zz, the first remaining larger tag gets a new
        // predecessor (0 if none), effective from depth[x] - 1 upward.
        for (auto zz : del[x]) {
            const pair<int, int> z{zz, 0};
            auto it = s[x].upper_bound(z);
            if (it == s[x].end()) continue;
            const auto k = *it;

            if (it == s[x].begin()) cover[k.first][k.second].push_back({0, depth[x] - 1});
            else cover[k.first][k.second].push_back({(*(--it)).first, depth[x] - 1});
        }
    };
    hld(1, 0);

    for (int r = 1; r <= m; r ++) {
        if (r != 1) {
            int x = c[r], y = c[r - 1], z = lca(x, y);
            // Vertex count of path r. When z is x or y this is the same value
            // the former special-cased branches computed.
            bit.add(r, depth[x] + depth[y] - 2 * depth[z] + 1);
            // Sentinels closing the last depth range of each side. Side 0
            // runs up to and including z; side 1 stops just below z.
            cover[r][0].push_back({0, depth[z] - 1});
            cover[r][1].push_back({0, depth[z]});

            for (int side = 0; side < 2; side ++) {
                const auto& segs = cover[r][side];
                // segs is never empty here (a sentinel was just pushed), and
                // `i + 1 < size` avoids the signed/unsigned `size() - 1` compare.
                for (size_t i = 0; i + 1 < segs.size(); i ++) {
                    const auto [prevPath, hiDepth] = segs[i];
                    const int loDepth = segs[i + 1].second;
                    // Depths in (loDepth, hiDepth] were already counted at prevPath.
                    if (prevPath) bit.add(prevPath, -(hiDepth - loDepth));
                }
            }
        }
        for (auto [l, i] : qr[r]) {
            if (l == r) ans[i] = 1;
            else ans[i] = bit.sum(r) - bit.sum(l);
        }
    }

    // '\n' instead of endl: avoids one flush per answer. The stream is still
    // flushed at program exit.
    for (int i = 0; i < q; i ++) cout << ans[i] << '\n';
}
/*
7 6 2
1 2
1 3
2 4
2 5
3 6
3 7
2 3 6 4 5 7
1 3
4 6
*/
int main() {
    // Desynchronize from C stdio and untie the streams for faster bulk I/O.
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);

    // The input contains exactly one test case.
    constexpr int kTests = 1;
    int remaining = kTests;
    while (remaining-- > 0) solve();
    return 0;
}

Compilation message (stderr)

tourism.cpp: In function 'void solve()':
tourism.cpp:172:31: warning: comparison of integer expressions of different signedness: 'int' and 'std::vector<std::pair<int, int> >::size_type' {aka 'long unsigned int'} [-Wsign-compare]
  172 |             for (int i = 0; i < fuck[r][0].size() - 1; i ++) {
      |                             ~~^~~~~~~~~~~~~~~~~~~~~~~
tourism.cpp:177:31: warning: comparison of integer expressions of different signedness: 'int' and 'std::vector<std::pair<int, int> >::size_type' {aka 'long unsigned int'} [-Wsign-compare]
  177 |             for (int i = 0; i < fuck[r][1].size() - 1; i ++) {
      |                             ~~^~~~~~~~~~~~~~~~~~~~~~~
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...