Submission #1197009

#TimeUsernameProblemLanguageResultExecution timeMemory
1197009madamadam3Werewolf (IOI18_werewolf)C++20
100 / 100
2349 ms361764 KiB
#include "werewolf.h"
#include <bits/stdc++.h>

using namespace std;

// Contest shorthands used throughout this file.
// FOR(i, a, b): i runs over [a, b). NOTE: a and b are substituted unparenthesized.
#define FOR(i, a, b) for (int i = a; i < b; i++)
#define pb push_back
#define all(x) (x).begin(), (x).end()
#define srt(x) sort(all(x))

typedef long long ll;
using vi = vector<int>;
using vl = vector<ll>;

// Weighted undirected edge between vertices u and v.
// Edges compare by weight alone, which is the order Kruskal processing needs.
struct Edge {
    int u, v, w;

    bool operator<(const Edge &other) const {
        return other.w > w;
    }
};

// Disjoint-set forest with path compression, used to find the root of a
// node's current component.
// There is deliberately NO union by size/rank: unite(a, b) always keeps a's
// representative as the root of the merged set, so the caller decides which
// node becomes the root. Because unions are unbalanced, parent chains can grow
// very long before they are compressed, so find_set is iterative.
struct DSU {
    int n;
    vector<int> par, size; // par: parent link (a root points to itself); size: set size, meaningful at roots

    explicit DSU(int N) {
        n = N;
        par.assign(n, 0); iota(par.begin(), par.end(), 0);
        size.assign(n, 1);
    }

    DSU() : n(0) {}

    // Returns the representative of v's set and points every node on the
    // visited path directly at it. Two passes (find the root, then relink)
    // instead of recursion, so deep chains cannot overflow the call stack.
    int find_set(int v) {
        int root = v;
        while (par[root] != root) root = par[root];

        while (par[v] != root) {
            int next = par[v];
            par[v] = root;
            v = next;
        }
        return root;
    }

    // Merges the sets containing a and b; a's representative stays the root.
    void unite(int a, int b) {
        a = find_set(a);
        b = find_set(b);

        if (a != b) {
            // Intentionally no size-based swap: callers rely on a remaining the root.
            par[b] = a;
            size[a] += size[b];
        }
    }
};

struct DSUTree {
    int n, m, k; // n = original graph size, m = num edges, k = 2n-1 = kruskal tree size
    DSU dsu;
    vector<vector<int>> adj; // adjlist for dsu tree
    vector<int> value, edgeid; // value of node i (0 for all leaves)
    vector<Edge> edge_list;

    int MAXLOG;
    vector<vector<int>> up;
    int timer;
    vector<int> tin, tout;

    DSUTree(int N, vector<Edge> edges) {
        n = N;
        m = edges.size();
        k = 2 * n - 1;
        edge_list = edges;
        dsu = DSU(k);

        adj.assign(k, vector<int>());
        value.assign(k, 0);
        edgeid.assign(k, 0);

        construct();
    }   

    void dfs(int u, int p) {
        tin[u] = timer++;
        up[u][0] = p;
        for (int i = 1; i < MAXLOG; i++) {
            up[u][i] = up[up[u][i - 1]][i - 1];
        }

        for (auto &v : adj[u]) {
            if (v == p) continue;
            dfs(v, u);
        }

        tout[u] = timer;
    }

    void construct() {
        sort(edge_list.begin(), edge_list.end());
        int cur_anc = n; // who shall we set the ancestor of u and v to for the current merge

        for (int i = 0; i < m; i++) {
            Edge cur = edge_list[i];
            int u = cur.u, v = cur.v, w = cur.w;
            int uhead = dsu.find_set(u), vhead = dsu.find_set(v);

            if (uhead != vhead) {
                dsu.unite(cur_anc, uhead);
                dsu.unite(cur_anc, vhead);

                adj[cur_anc].push_back(uhead);
                adj[cur_anc].push_back(vhead);
                adj[uhead].push_back(cur_anc);
                adj[vhead].push_back(cur_anc);

                value[cur_anc] = w;
                edgeid[cur_anc] = i;
                cur_anc++;
            }
        }

        MAXLOG = 0;
        while (k >= (1 << MAXLOG)) MAXLOG++;

        up.assign(k, vector<int>(MAXLOG, 0));
        tin.assign(k, 0);
        tout.assign(k, 0);
        timer = 0;
        dfs(dsu.find_set(0), dsu.find_set(0));
    }

    bool is_ancestor(int u, int v) {
        return tin[u] <= tin[v] && tout[u] >= tout[v];
    }

    int lca(int u, int v) {
        if (is_ancestor(u, v)) return u;
        if (is_ancestor(v, u)) return v;

        for (int i = MAXLOG - 1; i >= 0; i++) {
            if (!is_ancestor(up[u][i], v)) u = up[u][i];
        }

        return up[u][0];
    }

    int query_pathmax(int u, int v) { // max edge on path from u to v 
        int anc = lca(u, v);
        return value[anc];
    }

    bool can_reach(int u, int v, int x) { // can i reach v from u in first x nodes
        return edgeid[lca(u, v)] < x;
    }
};

// Segment-tree node storing the sum of its range.
// Child pointers are non-owning links; a null child contributes 0.
struct Node {
    Node *left, *right;
    int sum;

    // Leaf holding a single value.
    Node(int val) : left(nullptr), right(nullptr), sum(val) {}

    // Internal node whose sum is the sum of its (possibly null) children.
    Node(Node* L, Node* R)
        : left(L), right(R), sum((L ? L->sum : 0) + (R ? R->sum : 0)) {}
};

struct SegTree {
    int n;
    vector<Node*> roots;
    vector<int> arr;

    Node* build(int l, int r) {
        if (l + 1 == r) return new Node(arr[l]);

        int m = l + (r - l) / 2;
        return new Node(build(l, m), build(m, r));
    }

    Node* update(Node *current, int l, int r, int k, int v) {
        if (!(l <= k && k < r)) return current;
        if (l + 1 == r) return new Node(v);

        int m = l + (r - l) / 2;
        return new Node(update(current->left, l, m, k, v), update(current->right, m, r, k, v));
    }

    int query(Node *current, int l, int r, int ql, int qr) {
        if (r <= ql || qr <= l) return 0;
        if (ql <= l && r <= qr) return current->sum;

        int m = l + (r - l) / 2;
        return query(current->left, l, m, ql, qr) + query(current->right, m, r, ql, qr);
    }

    void update(int k, int v) {
        roots.push_back(update(roots.back(), 0, n, k, v));
    }

    int query(int l, int r, int time) {
        return query(roots[time], 0, n, l, r);
    }

    SegTree(int n, vector<int> arr) {
        this->n = n;
        this->arr = arr;

        roots.push_back(build(0, n));
    }
};

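/*
  check_validity parameters:
  N    = number of cities (vertices 0..N-1)
  X, Y = road endpoints: road j connects cities X[j] and Y[j]
  S, E = start and end city of each query
  L, R = per-query thresholds: in human form only cities >= L may be
         visited, in werewolf form only cities <= R
*/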

// NOTE(review): these globals are not referenced anywhere in the visible code;
// check_validity declares its own locals n, m, q that shadow them.
int n, m, q;
vi x, y, s, e, l, r;

vi check_validity(int N, vi X, vi Y, vi S, vi E, vi L, vi R) {
    int n = N, m = X.size(), q = S.size();

    // 1) Build the two reachability KRTs
    vector<Edge> edges_lo(m), edges_hi(m);
    FOR(i,0,m) {
        int u = X[i], v = Y[i];
        edges_lo[i] = {u, v, max(u,v)};     // to test "max vertex ≤ R"
        edges_hi[i] = {u, v, -min(u,v)};    // to test "min vertex ≥ L" (negated)
    }
    DSUTree DLo(n, edges_lo);
    DSUTree DHi(n, edges_hi);

    int K = 2*n - 1;
    auto &tin_lo = DLo.tin,  &tout_lo = DLo.tout;
    auto &tin_hi = DHi.tin,  &tout_hi = DHi.tout;

    // 2) Turn each original city i into a point (x = tin_hi[i], y = tin_lo[i])
    vector<pair<int,int>> pts(n);
    FOR(i,0,n) pts[i] = { tin_hi[i], tin_lo[i] };

    // 3) Bucket the y-coordinates by x so we can build versions in order
    vector<vector<int>> bucket(K);
    for (auto &p : pts) {
        bucket[p.first].push_back(p.second);
    }

    // 4) Build your persistent SegTree:
    //    version[x] = all points with tin_hi < x
    SegTree pst(K, vector<int>(K, 0));
    vector<int> ver(K+1, 0);
    // ver[0] = 0  (roots[0] is the all-zero tree)
    FOR(x0, 0, K) {
        // insert every point whose x == x0
        for (int y : bucket[x0]) {
            pst.update(y, 1);
        }
        // after those inserts, roots.back() is the version for x0+1
        ver[x0+1] = (int)pst.roots.size() - 1;
    }

    // 5) Answer each query:
    vi ans(q, 0);
    FOR(i,0,q) {
        // a) climb in DHi so that min(vertex) ≥ L[i]
        int u = S[i], thr_hi = -L[i];
        for (int d = DHi.MAXLOG-1; d >= 0; d--) {
            int p = DHi.up[u][d];
            if (DHi.value[p] <= thr_hi) u = p;
        }
        int a = tin_hi[u], b = tout_hi[u];

        // b) climb in DLo so that max(vertex) ≤ R[i]
        int v = E[i], thr_lo = R[i];
        for (int d = DLo.MAXLOG-1; d >= 0; d--) {
            int p = DLo.up[v][d];
            if (DLo.value[p] <= thr_lo) v = p;
        }
        int c = tin_lo[v], d2 = tout_lo[v];

        // c) rectangle count = version[b].sum[c,d2) - version[a].sum[c,d2)
        int cntR = pst.query(c, d2, ver[b]);
        int cntL = pst.query(c, d2, ver[a]);
        ans[i] = (cntR - cntL > 0);
    }

    return ans;
}
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...