Submission #531438

# | Time | Username | Problem | Language | Result | Execution time | Memory
531438 | Alex_tz307 | Werewolf (IOI18_werewolf) | C++17
100 / 100
955 ms | 171436 KiB
#include <bits/stdc++.h>
#include "werewolf.h"

using namespace std;
using vi = vector<int>;

// Capacity of the statically sized arrays in this file (vertex upper bound).
const int kN = 200000;

// Sizes of the current instance, set by check_validity:
// number of vertices, number of edges, number of queries.
int n;
int m;
int q;

// Next label dfs() will assign to a leaf of the merge tree.
int stamp;

// Adjacency lists of the input graph, indexed by vertex.
vi g[kN];

// Merge-tree storage shared by DSU::unite and dfs.
// Entries [0, n) are the leaves. Every entry appended by DSU::unite holds the
// two roots that were merged (its children). dfs() later overwrites each
// visited entry with the [lowest, highest] leaf label found in its subtree.
vector<pair<int, int>> t;

// Disjoint-set forest that records its merge history as a binary tree.
// Every successful unite() creates a brand-new node that becomes the parent of
// both merged roots, and appends the pair of merged roots to the global `t`.
// The callers in this file size `t` to n before the first unite(), so the
// index of the appended entry equals the id of the new node.
struct DSU {
  vi p;  // p[v] is the parent of v; a root is its own parent.

  // Resets the structure to n singleton sets {0}, ..., {n - 1}.
  void init() {
    p.assign(n, 0);
    iota(p.begin(), p.end(), 0);
  }

  // Returns the root of x's tree and re-points every node on the walked path
  // directly at that root.
  int root(int x) {
    int top = x;
    while (p[top] != top) {
      top = p[top];
    }
    while (x != top) {
      const int above = p[x];
      p[x] = top;
      x = above;
    }
    return top;
  }

  // Merges the sets of x and y. Does nothing when they already share a root.
  // Otherwise a new node with id p.size() is created as the common parent and
  // (root of x, root of y) is appended to `t` as its children.
  void unite(int x, int y) {
    const int rx = root(x);
    const int ry = root(y);
    if (rx != ry) {
      const int fresh = p.size();
      p[rx] = fresh;
      p[ry] = fresh;
      p.emplace_back(fresh);
      t.emplace_back(rx, ry);
    }
  }
};

// Labels the merge tree stored in `t`, starting from node v.
// A leaf (v < n) receives the next value of `stamp` and stores it as both ends
// of its range. An internal node reads its two children from t[v] (as written
// by DSU::unite), visits the first child and then the second, and finally
// overwrites t[v] with the smallest and largest leaf label in its subtree.
// Returns the range stored in t[v].
pair<int, int> dfs(int v) {
  if (v >= n) {
    const pair<int, int> lo = dfs(t[v].first);
    const pair<int, int> hi = dfs(t[v].second);
    t[v] = {min(lo.first, hi.first), max(lo.second, hi.second)};
    return t[v];
  }
  const int label = stamp;
  stamp += 1;
  t[v] = {label, label};
  return t[v];
}

// One vertex of the pointer-based segment tree.
// A freshly created node has a zero sum and no children.
struct node {
  int sum = 0;        // value aggregated over this node's range
  node* l = nullptr;  // left child
  node* r = nullptr;  // right child
};

// Allocates the complete subtree below x for the index range [lx, rx].
// x itself must already exist; a single-index range gets no children.
// The left subtree is fully built before the right child is allocated.
void build(node* x, int lx, int rx) {
  if (lx != rx) {
    const int mid = (lx + rx) / 2;
    node* left = new node();
    x->l = left;
    build(left, lx, mid);
    node* right = new node();
    x->r = right;
    build(right, mid + 1, rx);
  }
}

// Creates a new version of the tree in which the count at index `pos` is one
// larger than in the previous version.
//
//   prev   - node of the previous version covering [lx, rx]
//   curr   - already allocated node of the new version covering [lx, rx]
//   lx, rx - inclusive index range covered by prev/curr
//   pos    - index whose count is incremented
//
// Only the nodes on the path to `pos` are newly allocated; the sibling subtree
// at each level is shared with the previous version by copying its pointer.
void update(node* prev, node* curr, int lx, int rx, int pos) {
  if (lx == rx) {
    // Derive the leaf from the previous version rather than from curr's own
    // (zero-initialised) value, so repeated insertions at the same index
    // accumulate instead of every version reporting 1.
    curr->sum = (prev ? prev->sum : 0) + 1;
    return;
  }
  int mid = (lx + rx) / 2;
  if (pos <= mid) {
    curr->r = prev->r;  // untouched half is shared with the old version
    curr->l = new node();
    update(prev->l, curr->l, lx, mid, pos);
  } else {
    curr->l = prev->l;  // untouched half is shared with the old version
    curr->r = new node();
    update(prev->r, curr->r, mid + 1, rx, pos);
  }
  // Recompute the aggregate from whichever children exist.
  curr->sum = 0;
  if (curr->l) {
    curr->sum += curr->l->sum;
  }
  if (curr->r) {
    curr->sum += curr->r->sum;
  }
}

// Returns the sum stored over the indices of [lx, rx] that fall inside
// [st, dr], where x is the node covering [lx, rx].
// A node whose range lies entirely inside [st, dr] answers with its own sum;
// otherwise only the halves that can intersect the query are visited.
int query(node* x, int lx, int rx, int st, int dr) {
  const bool covered = st <= lx && rx <= dr;
  if (covered) {
    return x->sum;
  }
  const int mid = (lx + rx) / 2;
  const int left_part = (st <= mid) ? query(x->l, lx, mid, st, dr) : 0;
  const int right_part = (dr > mid) ? query(x->r, mid + 1, rx, st, dr) : 0;
  return left_part + right_part;
}

// Roots of the successive segment-tree versions built in check_validity:
// roots[0] is the empty tree produced by build(), and roots[k + 1] is created
// from roots[k] by one update() call, so roots[k] holds the first k points.
node* roots[1 + kN];

// One rectangle query: a label range [x1, x2] on the first axis, a label
// range [y1, y2] on the second axis, and the position of the query in the
// caller's input. Ordering compares x2 only.
struct query_t {
  int x1;
  int x2;
  int y1;
  int y2;
  int index;

  // True when this query's x2 is strictly smaller than rhs's x2.
  bool operator < (const query_t &rhs) const {
    return rhs.x2 > x2;
  }
};

// Answers, for every query i, whether some vertex lies both
//   - in the component of S[i] when only vertices >= L[i] are kept, and
//   - in the component of E[i] when only vertices <= R[i] are kept.
//
//   N          - number of vertices (must not exceed kN)
//   X, Y       - endpoints of the undirected edges
//   S, E, L, R - per-query start vertex, end vertex, lower and upper threshold
//
// Returns a vector with one entry per query: 1 if such a vertex exists,
// otherwise 0.
//
// Method: two merge trees are built with the DSU (one adding vertices from
// high to low, one from low to high). dfs() turns each tree into leaf-label
// ranges, so every vertex becomes a 2-D point and every query a rectangle.
// Rectangle emptiness is answered with a persistent segment tree.
//
// NOTE(review): dfs() is started from the last tree node only, so this relies
// on the graph being connected -- confirm against the task constraints.
// NOTE(review): segment-tree nodes are allocated with `new` and never freed.
vi check_validity(int N, vi X, vi Y, vi S, vi E, vi L, vi R) {
  n = N;
  m = X.size();
  q = S.size();
  // Drop edges left over from an earlier call; the lists are global.
  for (int v = 0; v < n; ++v) {
    g[v].clear();
  }
  for (int i = 0; i < m; ++i) {
    int u = X[i], v = Y[i];
    g[u].emplace_back(v);
    g[v].emplace_back(u);
  }
  // Sentinel query with L = 0 and R = n - 1: it forces both sweeps below to
  // process every vertex, so each merge tree is complete before dfs() runs.
  S.emplace_back(0);
  E.emplace_back(0);
  L.emplace_back(0);
  R.emplace_back(n - 1);
  q += 1;
  vi idx(q);
  iota(idx.begin(), idx.end(), 0);

  // First sweep: activate vertices n-1, n-2, ... and answer queries in
  // decreasing order of L.
  sort(idx.begin(), idx.end(), [&](const int &i, const int &j) -> bool {
    return L[i] > L[j];
  });
  DSU dsu;
  dsu.init();
  vi rep(q);
  // Exactly n leaf slots, so entries appended by unite() line up with node ids.
  t.assign(n, pair<int, int>());
  int u = n - 1;
  for (int qq = 0; qq < q; ++qq) {
    int i = idx[qq];
    while (u >= L[i]) {
      for (int v : g[u]) {
        if (v >= L[i]) {
          dsu.unite(u, v);
        }
      }
      u -= 1;
    }
    // Tree node whose subtree is S[i]'s component at threshold L[i].
    rep[i] = dsu.root(S[i]);
  }
  // Labels must start at 0 on every call; `stamp` is a global that a previous
  // invocation leaves at n.
  stamp = 0;
  dfs(static_cast<int>(t.size()) - 1);
  vector<query_t> queries(q);
  vector<pair<int, int>> points(n);
  for (int i = 0; i < q; ++i) {
    queries[i].x1 = t[rep[i]].first;
    queries[i].x2 = t[rep[i]].second;
    queries[i].index = i;
  }
  for (int i = 0; i < n; ++i) {
    points[i].first = t[i].first;
  }

  // Second sweep: activate vertices 0, 1, ... and answer queries in
  // increasing order of R.
  dsu.init();
  t.clear();
  t = vector<pair<int, int>>(n);
  sort(idx.begin(), idx.end(), [&](const int &i, const int &j) -> bool {
    return R[i] < R[j];
  });
  u = 0;
  for (int qq = 0; qq < q; ++qq) {
    int i = idx[qq];
    while (u <= R[i]) {
      for (int v : g[u]) {
        if (v <= R[i]) {
          dsu.unite(u, v);
        }
      }
      u += 1;
    }
    // Tree node whose subtree is E[i]'s component at threshold R[i].
    rep[i] = dsu.root(E[i]);
  }
  stamp = 0;
  dfs(static_cast<int>(t.size()) - 1);
  for (int i = 0; i < q; ++i) {
    queries[i].y1 = t[rep[i]].first;
    queries[i].y2 = t[rep[i]].second;
  }
  for (int i = 0; i < n; ++i) {
    points[i].second = t[i].first;
  }

  // Offline rectangle counting: points sorted by first label, queries by x2.
  sort(queries.begin(), queries.end());
  sort(points.begin(), points.end());
  vector<int> sol(q - 1);
  roots[0] = new node();
  build(roots[0], 0, n - 1);
  int ptr = 0;
  for (const query_t &it : queries) {
    if (it.index == q - 1) {
      continue;  // the sentinel has no slot in the answer
    }
    // Insert every point whose first label is <= x2; roots[k] then holds the
    // k points with the smallest first labels.
    while (ptr < n && points[ptr].first <= it.x2) {
      roots[ptr + 1] = new node();
      update(roots[ptr], roots[ptr + 1], 0, n - 1, points[ptr].second);
      ptr += 1;
    }
    // Points with first label in [x1, x2] = version ptr minus version x1.
    if (query(roots[ptr], 0, n - 1, it.y1, it.y2) - query(roots[it.x1], 0, n - 1, it.y1, it.y2)) {
      sol[it.index] = 1;
    }
  }
  return sol;
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...