// Submission #1083873
//
// # | Time | Username | Problem | Language | Result | Execution time | Memory
// 1083873 | (not captured) | horiseun | Werewolf (IOI18_werewolf) | C++17 | 100 / 100 | 1597 ms | 273944 KiB
#include <iostream>
#include <vector>
#include <numeric>
#include <cassert>
#include <map>
#include <tuple>
#include <functional>
#include <algorithm>
#include <unordered_map>
#include "werewolf.h"
using namespace std;

// Segment tree over the half-open position range [l, r). Each node stores in
// `sm` how many update() calls hit a position inside its range. The whole tree
// is built eagerly by the constructor, and every node owns its two children.
struct Node {
  int l, r, sm;
  Node *lft, *rht;  // owned children; both null when r - l <= 1

  // Builds the subtree covering [tl, tr). A range of length <= 1 (including an
  // empty one) becomes a childless node instead of recursing forever.
  Node(int tl, int tr): l(tl), r(tr), sm(0), lft(nullptr), rht(nullptr) {
    if (r - l > 1) {
      lft = new Node(l, (l + r) / 2);
      rht = new Node((l + r) / 2, r);
    }
  }

  // Frees the owned subtree.
  ~Node() {
    delete lft;
    delete rht;
  }

  // The children are owned through raw pointers, so a copy would double-delete.
  Node(const Node &) = delete;
  Node &operator=(const Node &) = delete;

  // Adds one to position `pos`. Positions outside [l, r) are ignored.
  void update(int pos) {
    if (pos < l || r <= pos) {
      return;
    }
    if (lft == nullptr) {
      sm++;
      return;
    }
    lft->update(pos);
    rht->update(pos);
    sm = lft->sm + rht->sm;
  }

  // Returns the total count stored at positions in [ql, qr) that also lie in [l, r).
  int query(int ql, int qr) {
    if (qr <= l || r <= ql) {
      return 0;
    }
    if (ql <= l && r <= qr) {
      return sm;
    }
    return lft->query(ql, qr) + rht->query(ql, qr);
  }
};

// ---- Global working storage for check_validity ----------------------------
// Vertex-indexed arrays hold 200005 entries. Arrays indexed by tree node hold
// 400005 entries: leaves 0..N-1 plus merge nodes numbered from N upward.
// "Tree 1" activates vertices N-1..0, and "tree 2" activates vertices 0..N-1.

// Disjoint-set parent links used while a tree is being built.
int par[400005];
// Nonzero once a vertex has been activated during the current tree build.
int processed[200005];
// jmp1[v][k] / jmp2[v][k]: the 2^k-th ancestor of v in tree 1 / tree 2 (-1 if absent).
int jmp1[400005][20];
int jmp2[400005][20];
// jmpmn[v][k] / jmpmx[v][k]: min (tree 1) / max (tree 2) of the subtree
// extremes of the ancestors covered by the matching jump.
int jmpmn[400005][20];
int jmpmx[400005][20];
// mn[v]: smallest leaf label under v in tree 1. mx[v]: largest leaf label under v in tree 2.
int mn[400005];
int mx[400005];
// Euler-tour entry and exit times in tree 1 and tree 2.
int in1[400005];
int out1[400005];
int in2[400005];
int out2[400005];
// pos[v]: index of vertex v inside baseRange1.
int pos[200005];

// Adjacency list of every vertex.
vector<int> adj[200005];
// Children lists of tree 1 and tree 2.
vector<int> krt1[400005];
vector<int> krt2[400005];
// Answers handed back by check_validity.
vector<int> ans;
// (Euler entry time, vertex) for the leaves of tree 1 / tree 2, in visit order.
vector<pair<int, int>> baseRange1;
vector<pair<int, int>> baseRange2;
// updates[y]: tree-1 leaf-index ranges [a, b] to be counted once the sweep reaches y.
vector<pair<int, int>> updates[200005];
// Per query: tree-1 leaf span [a, b] and tree-2 leaf span [c, d].
vector<tuple<int, int, int, int>> queries;
// ret[{a, b, y}]: number of vertices with tree-1 index in [a, b] among the
// first y + 1 vertices in tree-2 leaf order.
map<tuple<int, int, int>, int> ret;
// Segment tree used by the sweep.
Node *root;

// Returns one answer per query q (0 <= q < S.size()). The answer is 1 if some
// vertex is reachable from S[q] using only vertices labelled >= L[q] and is
// also reachable from E[q] using only vertices labelled <= R[q]. Otherwise it is 0.
//
// Parameters:
//   N            number of vertices, labelled 0..N-1
//   X, Y         undirected edges X[i]-Y[i]
//   S, E, L, R   query arrays of equal length
//
// Method:
//   1. Build two reconstruction trees. Tree 1 activates vertices N-1..0 and
//      tree 2 activates 0..N-1. Each union of two components adds a new
//      parent node (numbered from N) above both component roots.
//   2. From S[q], climb tree 1 by binary lifting while the ancestor's smallest
//      leaf is still >= L[q]. From E[q], climb tree 2 while the largest leaf
//      is <= R[q]. Each subtree reached is a contiguous span of leaves in that
//      tree's Euler-tour order.
//   3. Count, offline with a segment-tree sweep in tree-2 order, the vertices
//      whose positions fall inside both spans.
//
// This function uses and overwrites the file's global arrays, so it must not
// run concurrently. All of that state is reset on entry, so repeated calls are
// independent. A disconnected graph is handled one tree at a time.
vector<int> check_validity(int N, vector<int> X, vector<int> Y, vector<int> S, vector<int> E, vector<int> L, vector<int> R) {
  constexpr int kLog = 20;  // binary-lifting levels; must match the jump tables' second dimension
  const int Q = (int) S.size();
  const int total = 2 * N;  // each tree has at most 2N - 1 nodes

  // Clear everything a previous call may have left behind.
  for (int v = 0; v < N; v++) {
    adj[v].clear();
    updates[v].clear();
  }
  for (int v = 0; v < total; v++) {
    krt1[v].clear();
    krt2[v].clear();
  }
  baseRange1.clear();
  baseRange2.clear();
  queries.clear();
  ret.clear();
  ans.clear();

  for (int i = 0; i < (int) X.size(); i++) {
    adj[X[i]].push_back(Y[i]);
    adj[Y[i]].push_back(X[i]);
  }

  // Iterative DSU find with full path compression. Parent chains can be about
  // N long, so recursion here could exhaust the call stack.
  auto findRoot = [](int x) {
    int r = x;
    while (par[r] != r) {
      r = par[r];
    }
    while (par[x] != r) {
      const int up = par[x];
      par[x] = r;
      x = up;
    }
    return r;
  };

  // Builds one reconstruction tree into `krt`. Vertices are activated in
  // descending label order when `descending` is true, ascending otherwise.
  // Returns the number of nodes used: leaves 0..N-1 plus merge nodes.
  auto buildTree = [&](vector<int> *krt, bool descending) {
    iota(par, par + total, 0);
    fill(processed, processed + N, 0);
    int fresh = N;  // index of the next merge node
    for (int step = 0; step < N; step++) {
      const int v = descending ? N - 1 - step : step;
      processed[v] = true;
      for (int u : adj[v]) {
        if (!processed[u]) {
          continue;
        }
        const int a = findRoot(v);
        const int b = findRoot(u);
        if (a == b) {
          continue;
        }
        par[a] = par[b] = par[fresh] = fresh;
        krt[fresh].push_back(a);
        krt[fresh].push_back(b);
        fresh++;
      }
    }
    return fresh;
  };

  // For one tree, this helper does three things:
  //   - fills extreme[] with each node's smallest (takeMin) or largest leaf label;
  //   - sets the first binary-lifting level jmp[c][0] / jmpExt[c][0];
  //   - assigns Euler-tour times tin/tout, and appends each leaf to `leaves`
  //     as (entry time, vertex) in visit order.
  // It expects jmp[][0] == -1 and extreme[] pre-filled for internal nodes.
  auto indexTree = [&](vector<int> *krt, int nodeCount, int *extreme, bool takeMin,
                       int (*jmp)[kLog], int (*jmpExt)[kLog], int *tin, int *tout,
                       vector<pair<int, int>> &leaves) {
    // Every child has a smaller index than its parent, so increasing index
    // order visits children before parents.
    for (int x = 0; x < nodeCount; x++) {
      if (krt[x].empty()) {
        extreme[x] = x;
        continue;
      }
      for (int c : krt[x]) {
        extreme[x] = takeMin ? min(extreme[x], extreme[c]) : max(extreme[x], extreme[c]);
      }
      for (int c : krt[x]) {
        jmp[c][0] = x;
        jmpExt[c][0] = extreme[x];
      }
    }

    // Iterative Euler tour. Each entry is (node, whether all its children are done).
    int counter = 0;
    vector<pair<int, bool>> todo;
    for (int top = 0; top < nodeCount; top++) {
      if (jmp[top][0] != -1) {
        continue;  // start only from tree roots
      }
      todo.push_back({top, false});
      while (!todo.empty()) {
        auto [x, leaving] = todo.back();
        todo.pop_back();
        if (leaving) {
          tout[x] = counter++;
          continue;
        }
        tin[x] = counter++;
        if (krt[x].empty()) {
          leaves.push_back({tin[x], x});
        }
        todo.push_back({x, true});
        // Push in reverse so children are visited in their stored order.
        for (auto it = krt[x].rbegin(); it != krt[x].rend(); ++it) {
          todo.push_back({*it, false});
        }
      }
    }
  };

  const int nodes1 = buildTree(krt1, true);
  const int nodes2 = buildTree(krt2, false);

  fill(mn, mn + total, 2000000000);
  fill(mx, mx + total, -1);
  fill(&jmp1[0][0], &jmp1[0][0] + sizeof(jmp1) / sizeof(jmp1[0][0]), -1);
  fill(&jmp2[0][0], &jmp2[0][0] + sizeof(jmp2) / sizeof(jmp2[0][0]), -1);
  // Queries never consult an entry whose jump is -1, so these fill values do
  // not affect the answers.
  fill(&jmpmn[0][0], &jmpmn[0][0] + sizeof(jmpmn) / sizeof(jmpmn[0][0]), 2000000000);
  fill(&jmpmx[0][0], &jmpmx[0][0] + sizeof(jmpmx) / sizeof(jmpmx[0][0]), 2000000000);

  indexTree(krt1, nodes1, mn, true, jmp1, jmpmn, in1, out1, baseRange1);
  indexTree(krt2, nodes2, mx, false, jmp2, jmpmx, in2, out2, baseRange2);

  for (int k = 1; k < kLog; k++) {
    for (int v = 0; v < total; v++) {
      if (jmp1[v][k - 1] != -1) {
        jmp1[v][k] = jmp1[jmp1[v][k - 1]][k - 1];
        jmpmn[v][k] = min(jmpmn[v][k - 1], jmpmn[jmp1[v][k - 1]][k - 1]);
      }
      if (jmp2[v][k - 1] != -1) {
        jmp2[v][k] = jmp2[jmp2[v][k - 1]][k - 1];
        jmpmx[v][k] = max(jmpmx[v][k - 1], jmpmx[jmp2[v][k - 1]][k - 1]);
      }
    }
  }

  // Converts a subtree's Euler interval into the inclusive index span of its
  // leaves within `leaves`, which is sorted by entry time.
  auto leafSpan = [](const vector<pair<int, int>> &leaves, int enter, int leave) {
    const int lo = lower_bound(leaves.begin(), leaves.end(), pair{enter, -1}) - leaves.begin();
    const int hi = upper_bound(leaves.begin(), leaves.end(), pair{leave, -1}) - leaves.begin() - 1;
    return pair{lo, hi};
  };

  queries.reserve(Q);
  for (int q = 0; q < Q; q++) {
    int curr = S[q];
    for (int j = kLog - 1; j >= 0; j--) {
      if (jmp1[curr][j] != -1 && jmpmn[curr][j] >= L[q]) {
        curr = jmp1[curr][j];
      }
    }
    auto [a, b] = leafSpan(baseRange1, in1[curr], out1[curr]);

    curr = E[q];
    for (int j = kLog - 1; j >= 0; j--) {
      if (jmp2[curr][j] != -1 && jmpmx[curr][j] <= R[q]) {
        curr = jmp2[curr][j];
      }
    }
    auto [c, d] = leafSpan(baseRange2, in2[curr], out2[curr]);

    queries.push_back({a, b, c, d});
  }

  // Each query's answer is the prefix count at d minus the prefix count at c - 1.
  for (const auto &[a, b, c, d] : queries) {
    updates[d].push_back({a, b});
    if (c > 0) {
      updates[c - 1].push_back({a, b});
    }
  }
  for (int i = 0; i < N; i++) {
    pos[baseRange1[i].second] = i;
  }

  delete root;  // release the tree left by a previous call, if any
  root = new Node(0, N + 5);
  for (int y = 0; y < N; y++) {
    root->update(pos[baseRange2[y].second]);
    for (auto [a, b] : updates[y]) {
      ret[{a, b, y}] = root->query(a, b + 1);
    }
  }

  ans.reserve(Q);
  for (const auto &[a, b, c, d] : queries) {
    const int below = c > 0 ? ret[{a, b, c - 1}] : 0;
    ans.push_back(ret[{a, b, d}] - below > 0 ? 1 : 0);
  }
  return ans;
}
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...