Submission #596930

#TimeUsernameProblemLanguageResultExecution timeMemory
596930MilosMilutinovicJail (JOI22_jail)C++14
0 / 100
24 ms720 KiB
/**
 *    author:  wxhtzdy
 *    created: 08.07.2022 12:20:42
**/
#include <bits/stdc++.h>

using namespace std;

int main() {
  ios::sync_with_stdio(false);
  cin.tie(0);  
  int tt;
  cin >> tt;
  while (tt--) {
    int n;
    cin >> n;
    vector<vector<int>> g(n);
    for (int i = 0; i < n - 1; i++) {
      int u, v;
      cin >> u >> v;
      --u; --v;
      g[u].push_back(v);
      g[v].push_back(u);
    }
    int m;
    cin >> m;
    vector<int> s(m), t(m);
    for (int i = 0; i < m; i++) {
      cin >> s[i] >> t[i];
      --s[i]; --t[i];
    }
    const int L = 20;
    vector<vector<int>> jump(n, vector<int>(L));
    vector<int> dep(n);
    vector<int> tin(n);
    vector<int> tout(n);
    int T = 0;
    function<void(int, int)> Dfs = [&](int v, int pr) {
      dep[v] = dep[pr] + 1;
      tin[v] = ++T;
      jump[v][0] = pr;
      for (int u : g[v]) {
        if (u != pr) {
          Dfs(u, v);
        }
      }
      tout[v] = T;      
    };
    Dfs(0, 0);
    for (int j = 1; j < L; j++) {
      for (int i = 0; i < n; i++) {
        jump[i][j] = jump[jump[i][j - 1]][j - 1];
      }
    }
    auto LCA = [&](int u, int v) {
      if (dep[u] < dep[v]) {
        swap(u, v);
      }                    
      for (int i = L - 1; i >= 0; i--) {
        if (dep[jump[u][i]] >= dep[v]) {
          u = jump[u][i];
        }
      }
      for (int i = L - 1; i >= 0; i--) {
        if (jump[u][i] != jump[v][i]) {
          u = jump[u][i];
          v = jump[v][i];
        }
      }
      return u == v ? u : jump[u][0];
    };
    auto isPar = [&](int a, int b) {
      return tin[a] <= tin[b] && tout[b] <= tout[a];
    };
    auto onPath = [&](int a, int b, int c) {
      int L = LCA(b, c);
      return isPar(L, a) && (isPar(a, b) || isPar(a, c));
    };
    vector<set<pair<int, int>>> st(8 * n);                                      
    function<void(int, int, int, int, int, int, int)> ins = [&](int v, int l, int r, int ll, int rr, int x, int i) {
      if (l > r || l > rr || r < ll) {
        return;
      }
      if (ll <= l && r <= rr) {
        st[v].emplace(x, i);
        return;
      }
      int mid = l + r >> 1;
      ins(v * 2, l, mid, ll, rr, x, i);
      ins(v * 2 + 1, mid + 1, r, ll, rr, x, i);
    };
    function<void(int, int, int, int, int, int, int)> rem = [&](int v, int l, int r, int ll, int rr, int x, int i) {
      if (l > r || l > rr || r < ll) {
        return;
      }
      if (ll <= l && r <= rr) {
        st[v].erase(st[v].find({x, i}));
        return;
      }
      int mid = l + r >> 1;
      rem(v * 2, l, mid, ll, rr, x, i);
      rem(v * 2 + 1, mid + 1, r, ll, rr, x, i);
    };
    function<pair<int, int>(int, int, int, int)> get = [&](int v, int l, int r, int i) {
      if (l == r) {
        if (st[v].empty()) {
          return make_pair(-1, -1);
        } else {
          return *prev(st[v].end());
        }                 
      }
      int mid = l + r >> 1;
      pair<int, int> ret = make_pair(-1, -1);
      if (i <= mid) {                       
        ret = max(ret, get(v * 2, l, mid, i));
      } else {
        ret = max(ret, get(v * 2 + 1, mid + 1, r, i));
      }
      return ret;
    };
    vector<set<int>> alive(8 * n);
    function<void(int, int, int, int, int, int)> act = [&](int v, int l, int r, int ll, int rr, int x) {
      if (l > r || l > rr || r < ll) {
        return;
      }
      if (ll <= l && r <= rr) {
        alive[v].emplace(x);
        return;
      }
      int mid = l + r >> 1;
      act(v * 2, l, mid, ll, rr, x);
      act(v * 2 + 1, mid + 1, r, ll, rr, x);
    };
    function<void(int, int, int, int, int, int)> dec = [&](int v, int l, int r, int ll, int rr, int x) {
      if (l > r || l > rr || r < ll) {
        return;
      }
      if (ll <= l && r <= rr) {
        alive[v].erase(alive[v].find(x));
        return;
      }
      int mid = l + r >> 1;
      dec(v * 2, l, mid, ll, rr, x);
      dec(v * 2 + 1, mid + 1, r, ll, rr, x);
    };
    vector<int> vis(n);
    bool ok = true;
    function<void(int, int, int, int, int)> go = [&](int v, int l, int r, int i, int x) {
      if (!alive[v].empty()) {
        if (*prev(alive[v].end()) >= x) {
          ok = false;
        }
      }
      if (l == r) {
        return;
      }
      int mid = l + r >> 1;
      if (i <= mid) {
        go(v * 2, l, mid, i, x);
      } else {
        go(v * 2 + 1, mid + 1, r, i, x);
      }
    };
    function<void(int)> Go = [&](int i) {
      vis[i] = 1;         
      rem(1, 1, T, tin[s[i]], tout[s[i]], dep[s[i]], i);
      act(1, 1, T, tin[t[i]], tout[t[i]], dep[t[i]]);
      int L = LCA(s[i], t[i]);
      while (true) {
        pair<int, int> p = get(1, 1, T, tin[s[i]]);
        if (p.first < dep[L]) {
          break;
        }
        Go(p.second);
      }
      while (true) {
        pair<int, int> p = get(1, 1, T, tin[t[i]]);
        if (p.first < dep[L]) {
          break;
        }
        Go(p.second);
      }
      dec(1, 1, T, tin[t[i]], tout[t[i]], dep[t[i]]);  
      go(1, 1, T, tin[s[i]], dep[L]);
      go(1, 1, T, tin[t[i]], dep[L]);
      act(1, 1, T, tin[t[i]], tout[t[i]], dep[t[i]]);
      if (!ok) {
        return;
      }
      vis[i] = 2;
    };
    vector<int> order(m);
    iota(order.begin(), order.end(), 0);
    sort(order.begin(), order.end(), [&](int i, int j) {
      if (i == j) {
        return false;
      }
      if (onPath(s[i], s[j], t[j])) {
        return true;
      }
      if (onPath(s[j], s[i], t[i])) {
        return false;
      }
      if (onPath(t[i], s[j], t[j])) {
        return false;
      }
      if (onPath(t[j], s[i], t[i])) {
        return true;
      }
      return tin[t[i]] > tin[t[j]];
    });
    for (int i : order) {
      ins(1, 1, T, tin[s[i]], tout[s[i]], dep[s[i]], i);
    }
    for (int i : order) {
      if (vis[i] == 0) {
        Go(i);
      }
    }
    cout << (ok ? "Yes" : "No") << '\n';
  }                                                                    
  return 0;
}

Compilation message (stderr)

jail.cpp: In lambda function:
jail.cpp:88:19: warning: suggest parentheses around '+' inside '>>' [-Wparentheses]
   88 |       int mid = l + r >> 1;
      |                 ~~^~~
jail.cpp: In lambda function:
jail.cpp:100:19: warning: suggest parentheses around '+' inside '>>' [-Wparentheses]
  100 |       int mid = l + r >> 1;
      |                 ~~^~~
jail.cpp: In lambda function:
jail.cpp:112:19: warning: suggest parentheses around '+' inside '>>' [-Wparentheses]
  112 |       int mid = l + r >> 1;
      |                 ~~^~~
jail.cpp: In lambda function:
jail.cpp:130:19: warning: suggest parentheses around '+' inside '>>' [-Wparentheses]
  130 |       int mid = l + r >> 1;
      |                 ~~^~~
jail.cpp: In lambda function:
jail.cpp:142:19: warning: suggest parentheses around '+' inside '>>' [-Wparentheses]
  142 |       int mid = l + r >> 1;
      |                 ~~^~~
jail.cpp: In lambda function:
jail.cpp:157:19: warning: suggest parentheses around '+' inside '>>' [-Wparentheses]
  157 |       int mid = l + r >> 1;
      |                 ~~^~~
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...