Submission #953286

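| # | Time | Username | Problem | Language | Result | Execution time | Memory |
|---|------|----------|---------|----------|--------|----------------|--------|
| 953286 | | makrav | Tourism (JOI23_tourism) | C++14 | 100 / 100 | 1548 ms | 200328 KiB |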
#include <bits/stdc++.h>

using namespace std;
// 64-bit integer; used in main to encode a gap (l, r) as a single hash key.
using ll = long long;

// Competitive-programming shorthands used in main():
//   pb     -> push_back
//   all(a) -> full iterator range of a (not used in this file)
//   ff/sc  -> pair members first/second
//   sz(x)  -> container size converted to int
#define pb push_back
#define all(a) (a).begin(), (a).end()
#define ff first 
#define sz(x) (int)(x).size()
#define sc second

/// Fenwick (binary indexed) tree over positions 0..n-1, using the 0-based
/// "i & (i + 1)" indexing scheme: t[i] holds the sum of positions
/// (i & (i + 1)) .. i.
struct fenwick {
    int n = 0;        // number of addressable positions (0 when default-built)
    vector<int> t;    // partial sums, one slot per position
    fenwick() = default;

    /// Creates a tree with n_ positions, all zero.
    explicit fenwick(int n_) : n(n_) {
        t.assign(n, 0);
    }

    /// Resets every position to zero, keeping the size.
    void clear() {
        t.assign(n, 0);
    }

    /// Returns the sum of positions 0..x.
    /// x < 0 yields 0. x >= n is clamped to n - 1 (the total), because slots at
    /// or beyond n never receive updates and would give a wrong or
    /// out-of-bounds read.
    int sum(int x) {
        int ans = 0;
        for (int i = min(x, n - 1); i >= 0; i = (i & (i + 1)) - 1) {
            ans += t[i];
        }
        return ans;
    }

    /// Adds delta to position pos. Requires pos >= 0; pos >= n has no effect.
    void upd(int pos, int delta) {
        for (int i = pos; i < n; i = i | (i + 1)) {
            t[i] += delta;
        }
    }
};

// Sparse table over Euler-tour indices, filled in main: sp[k][i] is an index in
// [i, i + 2^k) with the smallest depth (htl). 18 levels and 200010 columns assume
// the Euler tour built in main (2n - 1 entries) stays below 200010 entries.
// NOTE(review): confirm this against the problem's bound on n.
int sp[18][200010];

signed main() {
    ios_base::sync_with_stdio(false);
    cout.tie(nullptr);
    vector<int> Log2(200010);
    for (int i = 2; i < 200010; i++) {
        Log2[i] = Log2[i / 2] + 1;
    }

    int n, m, q; cin >> n >> m >> q;
    vector<vector<int>> g(n);
    for (int i = 0; i < n - 1; i++) {
        int u, v; cin >> u >> v;
        u--; v--;
        g[u].pb(v);
        g[v].pb(u);
    }

    vector<vector<int>> num(n);
    vector<int> c(m);
    for (int i = 0; i < m; i++) {
        cin >> c[i];
        c[i]--;
        num[c[i]].pb(i);
    }
    vector<pair<int, int>> segs;
    vector<vector<int>> evs;
    vector<set<int>> sub(n);
    unordered_map<ll, int> pos;
    vector<int> h(n), siz(n);

    auto convert = [&](int l, int r) -> ll {
        return l * 1ll * m + r;
    };

    vector<int> htl, vtl, ppos(n);

    auto dfs = [&](int v, int p, auto&&dfs) -> void {
        htl.pb(h[v]);
        vtl.pb(v);
        ppos[v] = sz(vtl) - 1;
        siz[v] = sz(num[v]);
        vector<int> sons;
        for (auto &u : g[v]) {
            if (u != p) {
                h[u] = h[v] + 1;
                sons.pb(u);
                dfs(u, v, dfs);
                htl.pb(h[v]);
                vtl.pb(v);
                siz[v] += siz[u];
            }
        }
        auto add_ev = [&](int l, int r) {
            if (pos.find(convert(l, r)) == pos.end()) {
                pos[convert(l, r)] = sz(segs);
                segs.pb({ l, r });
                evs.pb({});
            }
            if (l == 0 && r == m - 1) {
                evs[pos[convert(l, r)]].pb(h[v]);
                evs[pos[convert(l, r)]].pb(h[v] - 1);
                return;
            }
            evs[pos[convert(l, r)]].pb(h[v]);
            };
        if (sons.empty()) {
            if (num[v].empty()) {
                add_ev(0, m - 1);
                return;
            }
            sub[v].insert(num[v][0]);
            segs.pb({0, num[v][0] - 1});
            evs.pb({h[v]});
            pos[convert(0, num[v][0] - 1)] = sz(segs) - 1;
            segs.pb({num[v][0] + 1, m - 1});
            pos[convert(num[v][0] + 1, m - 1)] = sz(segs) - 1;
            evs.pb({h[v]});
            for (int j = 1; j < sz(num[v]); j++) {
                int el = num[v][j];
                auto it = sub[v].lower_bound(el);
                if (it == sub[v].end()) {
                    int lst = (*sub[v].rbegin());
                    add_ev(lst + 1, m - 1);
                    add_ev(lst + 1, el - 1);
                    add_ev(el + 1, m - 1);
                }
                else if (it == sub[v].begin()) {
                    int bg = *it;
                    add_ev(0, bg - 1);
                    add_ev(el + 1, bg - 1);
                    add_ev(0, el - 1);
                }
                else {
                    int nxt = *it;
                    it--;
                    int pr = *it;
                    add_ev(pr + 1, nxt - 1);
                    add_ev(pr + 1, el - 1);
                    add_ev(el + 1, nxt - 1);
                }
                sub[v].insert(el);
            }
            return;
        }
        int maxson = sons[0];
        for (int i = 1; i < sz(sons); i++) {
            if (siz[sons[i]] > siz[maxson]) maxson = sons[i];
        }
        swap(sub[v], sub[maxson]);
        if (sz(sub[v]) == 0) {
            if (num[v].empty()) {
                add_ev(0, m - 1);
                return;
            }
            sub[v].insert(num[v][0]);
            segs.pb({ 0, num[v][0] - 1 });
            evs.pb({ h[v] });
            pos[convert(0, num[v][0] - 1)] = sz(segs) - 1;
            segs.pb({ num[v][0] + 1, m - 1 });
            pos[convert(num[v][0] + 1, m - 1)] = sz(segs) - 1;
            evs.pb({ h[v] });
            for (int j = 1; j < sz(num[v]); j++) {
                int el = num[v][j];
                auto it = sub[v].lower_bound(el);
                if (it == sub[v].end()) {
                    int lst = (*sub[v].rbegin());
                    add_ev(lst + 1, m - 1);
                    add_ev(lst + 1, el - 1);
                    add_ev(el + 1, m - 1);
                }
                else if (it == sub[v].begin()) {
                    int bg = *it;
                    add_ev(0, bg - 1);
                    add_ev(el + 1, bg - 1);
                    add_ev(0, el - 1);
                }
                else {
                    int nxt = *it;
                    it--;
                    int pr = *it;
                    add_ev(pr + 1, nxt - 1);
                    add_ev(pr + 1, el - 1);
                    add_ev(el + 1, nxt - 1);
                }
                sub[v].insert(el);
            }
            return;
        }
        

        for (int son : sons) {
            if (son != maxson) {
                int num = 0, prev = -1;
                for (int el : sub[son]) {
                    if (num == 0) {
                        add_ev(0, el - 1);
                    }
                    if (num == sz(sub[son]) - 1) {
                        add_ev(el + 1, m - 1);
                    }
                    if (num > 0) {
                        add_ev(prev + 1, el - 1);
                    }
                    prev = el;
                    num++;
                }
                for (int el : sub[son]) {
                    auto it = sub[v].lower_bound(el);
                    if (it == sub[v].end()) {
                        int lst = (*sub[v].rbegin());
                        add_ev(lst + 1, m - 1);
                        add_ev(lst + 1, el - 1);
                        add_ev(el + 1, m - 1);
                    } else if (it == sub[v].begin()) {
                        int bg = *it;
                        add_ev(0, bg - 1);
                        add_ev(el + 1, bg - 1);
                        add_ev(0, el - 1);
                    } else {
                        int nxt = *it;
                        it--;
                        int pr = *it;
                        add_ev(pr + 1, nxt - 1);
                        add_ev(pr + 1, el - 1);
                        add_ev(el + 1, nxt - 1);
                    }
                    sub[v].insert(el);
                }
                sub[son].clear();
            }
        }
        if (!num[v].empty()) {
            for (int el : num[v]) {
                auto it = sub[v].lower_bound(el);
                if (it == sub[v].end()) {
                    int lst = (*sub[v].rbegin());
                    add_ev(lst + 1, m - 1);
                    add_ev(lst + 1, el - 1);
                    add_ev(el + 1, m - 1);
                }
                else if (it == sub[v].begin()) {
                    int bg = *it;
                    add_ev(0, bg - 1);
                    add_ev(el + 1, bg - 1);
                    add_ev(0, el - 1);
                }
                else {
                    int nxt = *it;
                    it--;
                    int pr = *it;
                    add_ev(pr + 1, nxt - 1);
                    add_ev(pr + 1, el - 1);
                    add_ev(el + 1, nxt - 1);
                }
                sub[v].insert(el);
            }
        }
    };
    dfs(0, 0, dfs);
    for (int i = 0; i < sz(vtl); i++) {
        sp[0][i] = i;
    }
    auto comp = [&](int i, int j) {
        return (htl[i] < htl[j] ? i : j);
    };
    for (int pw = 1; pw < 18; pw++) {
        for (int i = 0; i + (1 << pw) - 1 < sz(vtl); i++) {
            sp[pw][i] = comp(sp[pw - 1][i], sp[pw - 1][i + (1 << (pw - 1))]);
        }
    }
    auto get = [&](int l, int r) {
        if (l > r) swap(l, r);
        int lg = Log2[r - l + 1];
        return comp(sp[lg][l], sp[lg][r - (1 << lg) + 1]);
    };

    auto lca = [&](int v, int u) {
        if (v == -1) return u;
        if (u == -1) return v;
        return vtl[get(ppos[v], ppos[u])];
    };

    vector<int> t(4 * m);
    auto build = [&](int v, int tl, int tr, auto&&build) -> void {
        if (tl + 1 == tr) {
            t[v] = c[tl];
            return;
        }
        int tm = (tl + tr) / 2;
        build(v * 2, tl, tm, build);
        build(v * 2 + 1, tm, tr, build);
        t[v] = lca(t[v * 2], t[v * 2 + 1]);
    };  

    auto getlca = [&](int v, int tl, int tr, int l, int r, auto&&getlca) {
        if (l <= tl && tr <= r) { return t[v]; }
        if (tl >= r || tr <= l) {
            return -1;
        }
        int tm = (tl + tr) / 2;
        return lca(getlca(v * 2, tl, tm, l, r, getlca), getlca(v * 2 + 1, tm, tr, l, r, getlca));
    };
    
    build(1, 0, m, build);
    vector<vector<pair<int, int>>> reqs(m), segss(m);
    vector<pair<int, int>> qrs(q);
    vector<int> cntt(sz(segs));
    for (int i = 0; i < sz(segs); i++) {

        for (int j = 0; j < sz(evs[i]); j++) {
            if (j % 2 == 0) cntt[i] += evs[i][j];
            else cntt[i] -= evs[i][j];
        }
        if (sz(evs[i]) & 1) cntt[i]++;
        if (segs[i].ff < m) {
            segss[segs[i].ff].pb({segs[i].sc, cntt[i]});
        }
    }
    for (int i = 0; i < q; i++) {
        int l, r; cin >> l >> r;
        l--; r--;
        qrs[i] = {l, r};
        reqs[l].pb({r, i});
    }
    vector<int> ans(q, 0);
    fenwick fw(m + 2);
    for (int i = 0; i < m; i++) {
        for (auto &u : segss[i]) {
            fw.upd(u.ff + 1, u.sc);
        }
        for (auto &u : reqs[i]) {
            ans[u.sc] = n - (fw.sum(m) - fw.sum(u.ff));
        }
    }
    for (int i = 0; i < q; i++) {
        cout << ans[i] - h[getlca(1, 0, m, qrs[i].ff, qrs[i].sc + 1, getlca)] << '\n';
    }
}
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...