Submission #1284579

# | Time | Username | Problem | Language | Result | Execution time | Memory
1284579 | | Pannda | Synchronization (JOI13_synchronization) | C++20 | 100 / 100 | 217 ms | 23724 KiB
#include <bits/stdc++.h>
using namespace std;

/// Rooted tree with heavy-light decomposition (HLD).
///
/// After construction:
///   depth[u]     distance from `root` (the root has depth 0)
///   parent[u]    parent of u, or -1 for the root
///   siz[u]       number of nodes in the subtree of u
///   head[u]      topmost node of the heavy chain containing u
///   begin[u]     preorder position of u; every node's heavy child (largest
///                subtree, first in adjacency order on ties) is numbered first,
///                so each heavy chain occupies consecutive positions
///   end[u]       one past the last position of u's subtree, i.e. the subtree
///                of u is exactly the positions [begin[u], end[u])
///   inv_begin[p] the node whose preorder position is p
///
/// Construction uses explicit stacks rather than recursion, so very deep trees
/// (e.g. a path of 10^5+ nodes) cannot overflow the call stack.
/// NOTE(review): assumes `adj` is an undirected tree containing `root`; nodes not
/// reachable from `root` are left at 0 and their entries are meaningless.
struct Tree {
    vector<int> depth, parent, siz, head, begin, end, inv_begin;
    Tree(const int &n, const vector<vector<int>> &adj, int root = 0) : depth(n), parent(n), siz(n), head(n), begin(n), end(n), inv_begin(n) {
        if (n == 0) return; // empty tree: nothing to index

        // Pass 1: parent/depth in DFS preorder (a node is recorded before its children).
        vector<int> order;
        order.reserve(n);
        vector<int> stk{root};
        parent[root] = -1;
        depth[root] = 0;
        while (!stk.empty()) {
            const int u = stk.back();
            stk.pop_back();
            order.push_back(u);
            siz[u] = 1;
            for (int v : adj[u]) {
                if (v == parent[u]) continue;
                parent[v] = u;
                depth[v] = depth[u] + 1;
                stk.push_back(v);
            }
        }
        // Children come after their parent in `order`, so a reverse sweep completes
        // every subtree size before it is added into the parent's.
        for (auto it = order.rbegin(); it != order.rend(); ++it) {
            if (parent[*it] != -1) siz[parent[*it]] += siz[*it];
        }

        // Pass 2: number nodes in preorder, heavy child first, assigning chain heads.
        int tim = 0;
        head[root] = root;
        stk.push_back(root);
        while (!stk.empty()) {
            const int u = stk.back();
            stk.pop_back();
            begin[u] = tim++;
            inv_begin[begin[u]] = u;
            end[u] = begin[u] + siz[u]; // the subtree of u takes the next siz[u] positions
            int heavy = -1;
            for (int v : adj[u]) {
                if (v == parent[u]) continue;
                if (heavy == -1 || siz[v] > siz[heavy]) heavy = v;
            }
            // LIFO: push light children in reverse adjacency order and the heavy child
            // last, so the heavy subtree is numbered first and light subtrees follow in
            // adjacency order -- the numbering of a recursive heavy-first DFS.
            for (auto it = adj[u].rbegin(); it != adj[u].rend(); ++it) {
                if (*it == parent[u] || *it == heavy) continue;
                head[*it] = *it; // a light child starts a new chain
                stk.push_back(*it);
            }
            if (heavy != -1) {
                head[heavy] = head[u]; // the heavy child continues u's chain
                stk.push_back(heavy);
            }
        }
    }
    /// Depth of u (the root has depth 0).
    int getDepth(int u) const {
        return depth[u];
    }
    /// Parent of u, or -1 if u is the root.
    int getParent(int u) const {
        return parent[u];
    }
    /// Returns the k-th ancestor of u (0th is u, 1st is parent(u), ...), or -1 if
    /// it would lie above the root. Jumps one heavy chain at a time.
    int getAncestor(int u, int k) const {
        while (k > 0 && u != -1) {
            const int stepsInChain = depth[u] - depth[head[u]];
            if (k > stepsInChain) {
                k -= stepsInChain + 1;
                u = parent[head[u]];
            } else {
                u = inv_begin[begin[u] - k]; // chain positions are consecutive
                k = 0;
            }
        }
        return u;
    }
    /// True if `lower` lies in the subtree of `upper` (a node is its own descendant).
    bool isDescendant(int upper, int lower) const {
        return begin[upper] <= begin[lower] && begin[lower] < end[upper];
    }
    /// True if `lower` lies in the subtree of `upper` and differs from it.
    bool isProperDescendant(int upper, int lower) const {
        return begin[upper] < begin[lower] && begin[lower] < end[upper];
    }
    /// Lowest common ancestor of u and v: climb the chain whose head is deeper.
    int getLCA(int u, int v) const {
        for (; head[u] != head[v]; v = parent[head[v]]) {
            if (depth[head[u]] > depth[head[v]]) swap(u, v);
        }
        if (depth[u] > depth[v]) swap(u, v);
        return u;
    }
    /// Number of edges on the path between u and v.
    int getDist(int u, int v) const {
        return depth[u] + depth[v] - 2 * depth[getLCA(u, v)];
    }
    /// Returns the neighbour of `from` on the path from `from` to `u`, i.e. the
    /// first node on that path that isn't `from`. Requires from != u.
    int getIntermediate(int from, int u) const {
        assert(from != u);
        if (!isDescendant(from, u)) {
            return parent[from]; // the path leaves `from` upward
        }
        while (head[u] != head[from]) {
            u = head[u];
            if (parent[u] == from) return u; // entered via a light edge out of `from`
            u = parent[u];
        }
        return inv_begin[begin[from] + 1]; // same chain: the heavy child of `from`
    }
    /// Maps a node to its preorder (HLD) position.
    int getNode(int u) const {
        return begin[u];
    }
    /// Maps a preorder (HLD) position back to its node.
    int getInvNode(int p) const {
        return inv_begin[p];
    }
    /// Subtree of u as the half-open position range {begin, end}.
    array<int, 2> getSubtree(int u) const {
        return array<int, 2>{ begin[u], end[u] };
    }
    /// Path between u and v (both inclusive) as O(log n) half-open position ranges
    /// {first, last + 1}. Each range runs top-down along one heavy chain; the ranges
    /// themselves are not sorted.
    vector<array<int, 2>> getPath(int u, int v) const {
        vector<array<int, 2>> res;
        for (; head[u] != head[v]; v = parent[head[v]]) {
            if (depth[head[u]] > depth[head[v]]) swap(u, v);
            res.push_back({ begin[head[v]], begin[v] + 1 });
        }
        if (depth[u] > depth[v]) swap(u, v);
        res.push_back({ begin[u], begin[v] + 1 });
        return res;
    }
    /// Builds the virtual (auxiliary) tree of `nodes`: the nodes plus the pairwise
    /// LCAs of preorder-adjacent ones (at most 2*|nodes| - 1 distinct nodes).
    /// Returns {nodes sorted by preorder position, edges as {parent, child}}.
    pair<vector<int>, vector<array<int, 2>>> virtualTree(vector<int> nodes) const {
        const auto byPosition = [&](int u, int v) { return begin[u] < begin[v]; };
        sort(nodes.begin(), nodes.end(), byPosition);
        const int k = (int)nodes.size();
        for (int i = 0; i + 1 < k; i++) {
            nodes.push_back(getLCA(nodes[i], nodes[i + 1]));
        }
        sort(nodes.begin(), nodes.end(), byPosition);
        nodes.erase(unique(nodes.begin(), nodes.end()), nodes.end());
        vector<array<int, 2>> edges;
        vector<int> stk; // current root-to-node chain of the virtual tree
        for (int u : nodes) {
            while (!stk.empty() && !isProperDescendant(stk.back(), u)) {
                stk.pop_back();
            }
            if (!stk.empty()) {
                edges.push_back({stk.back(), u});
            }
            stk.push_back(u);
        }
        return {nodes, edges};
    }
};

/// Recursive point-update / range-query segment tree over positions [0, n).
///
/// Requirements on Info (as used below): `Info()` is the identity element, and
/// `a + b` combines a range `a` with the range `b` immediately to its right.
template<class Info>
struct SegmentTree {
    int n;
    vector<Info> info;
    SegmentTree() : n(0) {}
    SegmentTree(int n_, Info v_ = Info()) {
        init(n_, v_);
    }
    template<class T>
    SegmentTree(vector<T> init_) {
        init(init_);
    }
    /// Rebuilds the tree with n_ copies of v_.
    void init(int n_, Info v_ = Info()) {
        init(vector<Info>(n_, v_));
    }
    /// Rebuilds the tree from init_ (each element converted to Info). An empty
    /// vector yields an empty tree whose queries return Info() / -1.
    template<class T>
    void init(std::vector<T> init_) {
        n = (int)init_.size();
        // A node at depth d has index in [2^d, 2^(d+1)) and leaves sit at depth
        // <= ceil(log2 n), so 2 * (n rounded up to a power of two) slots suffice.
        int cap = 1;
        while (cap < n) cap <<= 1;
        info.assign(2 * cap, Info());
        if (n > 0) build(1, 0, n, init_); // n == 0 would never reach a leaf
    }
    /// Fills node p, covering [l, r) with r - l >= 1, from init_.
    template<class T>
    void build(int p, int l, int r, const std::vector<T> &init_) {
        if (r - l == 1) {
            info[p] = init_[l];
            return;
        }
        int m = (l + r) / 2;
        build(2 * p, l, m, init_);
        build(2 * p + 1, m, r, init_);
        pull(p);
    }
    /// Recomputes node p from its two children.
    void pull(int p) {
        info[p] = info[2 * p] + info[2 * p + 1];
    }
    void modify(int p, int l, int r, int x, const Info &v) {
        if (r - l == 1) {
            info[p] = v;
            return;
        }
        int m = (l + r) / 2;
        if (x < m) {
            modify(2 * p, l, m, x, v);
        } else {
            modify(2 * p + 1, m, r, x, v);
        }
        pull(p);
    }
    /// Sets position p to v. An out-of-range p would silently overwrite a boundary
    /// leaf (or, on an empty tree, never terminate), hence the assert.
    void modify(int p, const Info &v) {
        assert(0 <= p && p < n);
        modify(1, 0, n, p, v);
    }
    Info rangeQuery(int p, int l, int r, int x, int y) {
        if (l >= y || r <= x) {
            return Info();
        }
        if (l >= x && r <= y) {
            return info[p];
        }
        int m = (l + r) / 2;
        return rangeQuery(2 * p, l, m, x, y) + rangeQuery(2 * p + 1, m, r, x, y);
    }
    /// Combination of positions [l, r); Info() for an empty range.
    Info rangeQuery(int l, int r) {
        return rangeQuery(1, 0, n, l, r);
    }
    template<class F>
    int findFirst(int p, int l, int r, int x, int y, F &&pred) {
        if (l >= y || r <= x) {
            return -1;
        }
        if (l >= x && r <= y && !pred(info[p])) {
            return -1; // no position inside this fully covered node can satisfy pred
        }
        if (r - l == 1) {
            return l;
        }
        int m = (l + r) / 2;
        int res = findFirst(2 * p, l, m, x, y, pred);
        if (res == -1) {
            res = findFirst(2 * p + 1, m, r, x, y, pred);
        }
        return res;
    }
    /// Leftmost position i in [l, r) whose leaf satisfies pred, or -1. Whole nodes
    /// failing pred are skipped, so pred must be monotone under `+` (e.g. "min == 0").
    template<class F>
    int findFirst(int l, int r, F &&pred) {
        return findFirst(1, 0, n, l, r, pred);
    }
    template<class F>
    int findLast(int p, int l, int r, int x, int y, F &&pred) {
        if (l >= y || r <= x) {
            return -1;
        }
        if (l >= x && r <= y && !pred(info[p])) {
            return -1;
        }
        if (r - l == 1) {
            return l;
        }
        int m = (l + r) / 2;
        int res = findLast(2 * p + 1, m, r, x, y, pred);
        if (res == -1) {
            res = findLast(2 * p, l, m, x, y, pred);
        }
        return res;
    }
    /// Rightmost position i in [l, r) whose leaf satisfies pred, or -1 (same
    /// requirement on pred as findFirst).
    template<class F>
    int findLast(int l, int r, F &&pred) {
        return findLast(1, 0, n, l, r, pred);
    }
};

/// Segment-tree value: the minimum over a range.
/// The default mn (10^9) is the identity for `+` as long as every stored value is <= 10^9.
struct Info {
    int mn = 1'000'000'000;
    /// Combines two ranges; const so it also works on const Info values.
    Info operator+(const Info &b) const {
        return {min(mn, b.mn)};
    }
};

// Reads the tree, replays the M edge toggles, then answers the Q queries.
//
// A node r is the top of its component when its parent edge is off (or r is the
// root); queries print cnt[top of the node's component]. Joining child v's
// component into its parent's adds cnt[v] - last[v], where last[v] is the value
// cnt[v] was given when that edge was last switched off.
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int n, m, q;
    cin >> n >> m >> q;

    vector<vector<int>> adj(n);
    vector<array<int, 2>> edges;
    for (int i = 0; i + 1 < n; ++i) {
        int a, b;
        cin >> a >> b;
        --a;
        --b;
        adj[a].push_back(b);
        adj[b].push_back(a);
        edges.push_back({a, b});
    }

    Tree tree(n, adj, 0);
    // Orient every edge as {parent, child}.
    for (auto &e : edges) {
        if (tree.getDepth(e[0]) > tree.getDepth(e[1])) swap(e[0], e[1]);
    }
    vector<int> linked(n, 0); // linked[v] != 0 while the edge (parent(v), v) is on
    vector<int> cnt(n, 1);
    vector<int> last(n, 0);   // refreshed whenever v's parent edge is switched off
    // Position of v stores 1 while v's parent edge is on, 0 otherwise.
    SegmentTree<Info> seg(n, {0});

    // Deepest ancestor of u (u included) whose stored value is 0. getPath(0, u)
    // lists chain ranges starting from u's own chain, so the first hit is the answer.
    auto componentTop = [&](int u) -> int {
        for (const auto &range : tree.getPath(0, u)) {
            const int pos = seg.findLast(range[0], range[1], [](Info info) { return info.mn == 0; });
            if (pos != -1) return tree.inv_begin[pos];
        }
        return 0;
    };
    auto toggle = [&](int id) {
        const int up = edges[id][0];
        const int child = edges[id][1];
        if (!linked[child]) {
            linked[child] = true;
            seg.modify(tree.getNode(child), {1});
            const int top = componentTop(up);
            cnt[top] += cnt[child] - last[child];
        } else {
            linked[child] = false;
            seg.modify(tree.getNode(child), {0});
            cnt[child] = cnt[componentTop(up)];
            last[child] = cnt[child];
        }
    };

    for (; m != 0; --m) {
        int e;
        cin >> e;
        toggle(e - 1);
    }

    for (; q != 0; --q) {
        int u;
        cin >> u;
        cout << cnt[componentTop(u - 1)] << '\n';
    }

}

# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...