답안 #578111

# 제출 시각 아이디 문제 언어 결과 실행 시간 메모리
578111 2022-06-16T05:10:24 Z Kanon Simurgh (IOI17_simurgh) C++14
0 / 100
1 ms 468 KB
#include <bits/stdc++.h>
#include "simurgh.h"

using namespace std;

// Disjoint-set union over the elements 0..n-1, with path compression.
class dsu {
 public:
  vector<int> p;  // p[x] is the parent of x; x is a set root iff p[x] == x
  int n;          // number of elements

  dsu(int _n) : n(_n) {
    p.assign(n, 0);
    for (int i = 0; i < n; ++i) {
      p[i] = i;  // every element starts as its own singleton set
    }
  }

  // Returns the root of x's set and re-points every node on the walked path
  // directly at that root.
  int get(int x) {
    int root = x;
    while (p[root] != root) {
      root = p[root];
    }
    while (p[x] != root) {
      const int next = p[x];
      p[x] = root;
      x = next;
    }
    return root;
  }

  // Merges the sets containing x and y by hanging x's root under y's root.
  // Returns false (and changes nothing) when they already share a set.
  bool unite(int x, int y) {
    const int rx = get(x);
    const int ry = get(y);
    if (rx == ry) {
      return false;
    }
    p[rx] = ry;
    return true;
  }
};

vector<int> find_roads(int n, vector<int> n1, vector<int> n2) {
  int m = n1.size();
  vector<vector<int>> g(n);
  for (int i = 0; i < m; i++) {
    g[n1[i]].push_back(i);
    g[n2[i]].push_back(i);
  }

  vector<int> par(n, -1);
  vector<int> dep(n);
  vector<int> was(n);
  function<void(int, int)> dfs = [&](int v, int p) {
    was[v] = 1;
    par[v] = p;
    for (int e : g[v]) {
      if (e == p) {
        continue;
      }
      int u = n1[e] ^ n2[e] ^ v;
      if (was[u]) {
        continue;
      }
      dep[u] = dep[v] + 1;
      dfs(u, e);
    }
  };
  dfs(0, -1);

  set<int> tree_edges;
  for (int i = 0; i < n; i++) {
    if (par[i] != -1) {
      tree_edges.insert(par[i]);
    }
  }

  vector<int> ret(m, -1);
  auto royal = [&](set<int> S) {
    vector<int> p;
    for (int i : S) {
      p.push_back(i);
    }
    return count_common_roads(p);
  };



  {
    auto handle_cycle = [&](vector<int> c) {
      assert(c.size() >= 3);
      set<int> S = tree_edges;

      int e = -1, ve = -1;
      for (int i : c) {
        if (ret[i] != -1) {
          e = i;
          ve = ret[i];
        }
        S.insert(i);
      }
      assert((int) S.size() == n);

      if (e != -1) {
        S.erase(e);
        int vS = royal(S) + ve;
        S.insert(e);
        for (int i : c) {
          if (ret[i] != -1) {
            continue;
          }
          S.erase(i);
          ret[i] = vS - royal(S);
          S.insert(i);
        }
      } else {
        vector<int> val;
        for (int i : c) {
          S.erase(i);
          val.push_back(royal(S));
          S.insert(i);
        }
        int mx = *max_element(val.begin(), val.end());
        for (int i = 0; i < (int) c.size(); i++) {
          if (val[i] == mx) {
            ret[c[i]] = 0;
          } else {
            ret[c[i]] = 1;
          }
        }
      }
    };

    function<vector<int>(int)> calc = [&](int v) {
      vector<int> cycle;
      int highest = dep[v] - 1;
      int back_ed = -1;
      for (int e : g[v]) {
        int u = n1[e] ^ n2[e] ^ v;
        if (dep[u] < highest) {
          highest = dep[u];
          back_ed = e;
        }
      }

      if (back_ed != -1) {
        int cur = v;
        while (highest < dep[cur]) {
          int e = par[cur];
          cycle.push_back(e);
          cur = n1[e] ^ n2[e] ^ cur;
        }
        cycle.push_back(back_ed);
      }

      for (int ed : g[v]) {
        int u = n1[ed] ^ n2[ed] ^ v;
        if (ed != par[u]) {
          continue;
        }
        vector<int> now = calc(u);
        if (now.empty()) {
          continue;
        }
        int val = n;
        for (int e : now) {
          val = min(val, min(dep[n1[e]], dep[n2[e]]));
        }
        if (val < highest) {
          cycle = now;
          highest = val;
        }
      }

      if (!cycle.empty()) {
        handle_cycle(cycle);
      } else {
        if (par[v] != -1) {
          ret[par[v]] = 1;
        }
      }
      return cycle;
    };

    calc(0);
  }






  auto make_tree = [&](vector<int> edges) {
    dsu d(n);
    set<int> S;
    int v = 0;
    for (int e : edges) {
      assert(d.unite(n1[e], n2[e]));
      S.insert(e);
    }
    for (int e : tree_edges) {
      assert(ret[e] != -1);
      if (d.unite(n1[e], n2[e])) {
        S.insert(e);
        v += ret[e];
      }
    }
    return make_pair(S, v);
  };

  function<void(set<int>)> calc = [&](set<int> nodes) {
    if (nodes.size() == 1) {
      return;
    }

    int divide = -1;
    for (int e : tree_edges) {
      if (nodes.find(n1[e]) != nodes.end() && nodes.find(n2[e]) != nodes.end()) {
        divide = e;
      }
    }
    assert(divide != -1);

    dsu d(n);
    for (int e : tree_edges) {
      if (e != divide) {
        d.unite(n1[e], n2[e]);
      }
    }

    int a = n1[divide], b = n2[divide];
    set<int> A, B;
    for (int i : nodes) {
      if (d.get(a) == d.get(i)) {
        A.insert(i);
      } else {
        B.insert(i);
      }
    }

    calc(A); calc(B);

    if (A.size() > B.size()) {
      swap(A, B);
    }
    vector<int> order;

    for (int v : A) {
      for (int e : g[v]) {
        int u = n1[e] ^ n2[e] ^ v;
        if (B.find(u) == B.end() || ret[e] != -1) {
          continue;
        }
        order.push_back(e);
      }
    }

    vector<vector<int>> forest;
    set<pair<int, int>> dead;
    for (int i = 0; i < (int) order.size(); i++) {
      dead.insert({i, order[i]});
    }

    while (!dead.empty()) {
      vector<int> now;
      d = dsu(n);
      vector<pair<int, int>> rev;
      for (auto pe : dead) {
        int e = pe.second;
        if (ret[e] != -1) {
          continue;
        }
        if (d.unite(n1[e], n2[e])) {
          now.push_back(e);
          rev.push_back(pe);
        }
      }
      for (auto pe : rev) {
        dead.erase(pe);
      }
      forest.push_back(now);
    }

    for (vector<int> edges : forest) {
      int sz = edges.size();
      auto pi = make_tree(edges);
      int Valued = royal(pi.first) - pi.second;
      if (Valued == 0) {
        for (int e : edges) {
          ret[e] = 0;
        }
        continue;
      }

      function<void(int, int, int)> solve = [&](int L, int R, int val) {
        if (L == R) {
          ret[edges[L]] = val;
          return;
        }

        int mid = (L + R) >> 1;
        vector<int> now;
        for (int i = L; i <= mid; i++) {
          now.push_back(edges[i]);
        }

        auto Pi = make_tree(now);
        int t = royal(Pi.first) - Pi.second;

        if (val - t > 0) {
          solve(mid + 1, R, val - t);
        } else {
          for (int i = mid + 1; i <= R; i++) {
            ret[edges[i]] = 0;
          }
        }

        if (t > 0) {
          solve(L, mid, t);
        } else {
          for (int i = L; i <= mid; i++) {
            ret[edges[i]] = 0;
          }
        }
      };

      solve(0, sz - 1, Valued);
    }

  };

  set<int> nodes;
  for (int i = 0; i < n; i++) {
    nodes.insert(i);
  }

  calc(nodes);

  vector<int> res;
  for (int i = 0; i < m; i++) {
    assert(ret[i] != -1);
    if (ret[i] == 1) {
      res.push_back(i);
    }
  }
  
  assert((int) res.size() == n - 1);

  return res;

}

Compilation message

simurgh.cpp: In lambda function:
simurgh.cpp:219:25: warning: unused variable 'b' [-Wunused-variable]
  219 |     int a = n1[divide], b = n2[divide];
      |                         ^
# 결과 실행 시간 메모리 Grader output
1 Runtime error 1 ms 340 KB Execution killed with signal 6
2 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Runtime error 1 ms 340 KB Execution killed with signal 6
2 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Runtime error 1 ms 340 KB Execution killed with signal 6
2 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Correct 1 ms 212 KB correct
2 Runtime error 1 ms 468 KB Execution killed with signal 6
3 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Runtime error 1 ms 340 KB Execution killed with signal 6
2 Halted 0 ms 0 KB -