Submission #443279

# | Time | Username | Problem | Language | Result | Execution time | Memory
443279 | | peijar | Split the Attractions (IOI19_split) | C++17 | 100 / 100 | 254 ms | 22408 KiB
#include "split.h"
#include <bits/stdc++.h>
using namespace std;

// Disjoint-set forest with union by size and path compression.
// Encoding of `id`:
//   id[x] <  0  -> x is a root and -id[x] is the size of its set;
//   id[x] >= 0  -> id[x] is the parent of x.
struct UnionFind {
  std::vector<int> id;

  UnionFind() {}
  UnionFind(int N) : id(N, -1) {}

  // Number of elements in the set that contains u.
  int sz(int u) { return -id[find(u)]; }

  // Representative of u's set. Every node on the walked path is re-pointed
  // directly at the representative.
  int find(int u) {
    int root = u;
    while (id[root] >= 0)
      root = id[root];
    while (u != root) {
      const int next = id[u];
      id[u] = root;
      u = next;
    }
    return root;
  }

  // Joins the sets containing u and v. Returns false if they were already the
  // same set. The root of the larger set survives; on equal sizes, u's root
  // survives.
  bool merge(int u, int v) {
    int keep = find(u);
    int drop = find(v);
    if (keep == drop)
      return false;
    if (-id[keep] < -id[drop])
      std::swap(keep, drop);
    id[keep] += id[drop];
    id[drop] = keep;
    return true;
  }
};

// Capacity of every per-vertex global array below.
constexpr int MAXN = 200000;

// Vertex flags for dfsRestricted: `allowed` = may be taken,
// `seen` = already taken.
bitset<MAXN> allowed;
bitset<MAXN> seen;

// Adjacency lists (filled by find_split and solve) and the raw edge list.
vector<int> adj[MAXN];
vector<pair<int, int>> aretes;

// dsu: used while reading edges to decide which ones go into `adj`.
// dsu2: groups of non-centroid vertices.
UnionFind dsu(MAXN);
UnionFind dsu2(MAXN);

// Filled by dfs: subtree size, largest child-subtree size, parent.
int sz[MAXN];
int maxSz[MAXN];
int par[MAXN];

// Vertex and edge counts of the current instance.
int nbSommets;
int nbAretes;

// Output and remaining quota of dfsRestricted.
vector<int> vertices;
int nbRestant;

// Centroid chosen by find_split.
int centroid = 0;

// Roots the graph stored in `adj` at `u`, recording `p` as u's parent. For
// every vertex x reached it fills:
//   par[x]   - parent of x (par[u] = p),
//   sz[x]    - number of vertices in x's subtree,
//   maxSz[x] - size of x's largest child subtree (0 for a leaf).
// Only the neighbour equal to the parent is skipped, so `adj` must be acyclic
// (a forest) for this to terminate.
// Uses an explicit stack: a recursive walk is as deep as the tree, which can
// overflow the call stack on long path-like trees.
void dfs(int u, int p) {
  par[u] = p;
  sz[u] = 1;
  maxSz[u] = 0;
  // Each frame: (vertex, index of the next neighbour to examine).
  vector<pair<int, size_t>> frames;
  frames.emplace_back(u, 0);
  while (!frames.empty()) {
    const int x = frames.back().first;
    size_t &next = frames.back().second;
    if (next < adj[x].size()) {
      const int v = adj[x][next++];
      if (v != par[x]) {
        par[v] = x;
        sz[v] = 1;
        maxSz[v] = 0;
        frames.emplace_back(v, 0); // invalidates `next`; it is not used again
      }
      continue;
    }
    // x's subtree is complete: fold its size into the parent frame, if any.
    frames.pop_back();
    if (!frames.empty()) {
      const int y = frames.back().first;
      sz[y] += sz[x];
      maxSz[y] = max(maxSz[y], sz[x]);
    }
  }
}

// Depth-first walk from `u` that takes vertices until the quota `nbRestant`
// reaches 0. A vertex is taken only if it is `allowed` and not yet `seen`.
// Taking a vertex marks it `seen`, decrements `nbRestant` and appends it to
// `vertices`. The walk only continues through taken vertices, so every taken
// vertex is adjacent to an earlier one and the taken set is connected.
// Iterative, with the same visiting order as a recursive walk, so that taking
// many vertices cannot overflow the call stack.
void dfsRestricted(int u) {
  // Tries to take x; returns whether it was taken.
  auto take = [](int x) {
    if (seen[x] or !allowed[x] or !nbRestant)
      return false;
    seen[x] = true;
    nbRestant--;
    vertices.push_back(x);
    return true;
  };
  if (!take(u))
    return;
  // Each frame: (vertex, index of the next neighbour to examine).
  vector<pair<int, size_t>> frames;
  frames.emplace_back(u, 0);
  while (!frames.empty() and nbRestant) {
    const int x = frames.back().first;
    size_t &next = frames.back().second;
    if (next == adj[x].size()) {
      frames.pop_back();
      continue;
    }
    const int v = adj[x][next++];
    if (take(v))
      frames.emplace_back(v, 0); // invalidates `next`; it is not used again
  }
}

// Builds the final labelling once some dsu2 group other than the centroid has
// at least min(a, b, c) vertices. find_split only calls it in that situation.
// Returns n labels in {1, 2, 3}; label k+1 means "part of size (a, b, c)[k]".
//   1. The smallest part is a connected set of that many vertices taken inside
//      the chosen group.
//   2. The middle part is taken by a walk from the centroid through the
//      remaining vertices.
//   3. Every vertex left over gets the largest part's label.
// The asserts check that each walk met its quota.
vector<int> solve(int a, int b, int c) {
  // From here on the walks may use every input edge, not only those already
  // in `adj`.
  for (auto [u, v] : aretes)
    adj[u].push_back(v), adj[v].push_back(u);

  const int minSz = min({a, b, c});
  // Representative of a non-centroid group with >= minSz vertices (the last
  // one found when scanning vertices in index order).
  int goodCC = -1;
  for (int i = 0; i < nbSommets; ++i)
    if (i != centroid and dsu2.sz(i) >= minSz)
      goodCC = dsu2.find(i);
  assert(goodCC != -1);

  // Smallest part: minSz connected vertices restricted to goodCC's group.
  for (int i = 0; i < nbSommets; ++i)
    if (goodCC == dsu2.find(i))
      allowed[i] = true;
  nbRestant = minSz;
  dfsRestricted(goodCC);
  assert(!nbRestant);

  // Requested part sizes, indexed like (a, b, c); `order` sorts those indices
  // by size. (Named `want` so it does not shadow the global sz[] array.)
  const int want[3] = {a, b, c};
  int order[3] = {0, 1, 2};
  sort(order, order + 3, [&](int i, int j) { return want[i] < want[j]; });

  vector<int> ret(nbSommets);
  allowed.set();
  for (int u : vertices) {
    ret[u] = order[0] + 1;
    allowed[u] = false;
  }
  vertices.clear();

  // Middle part: walk from the centroid over everything not yet taken.
  nbRestant = want[order[1]];
  dfsRestricted(centroid);
  assert(!nbRestant);
  for (int u : vertices)
    ret[u] = order[1] + 1;

  // Largest part: all remaining vertices.
  for (int i = 0; i < nbSommets; ++i)
    if (!ret[i])
      ret[i] = order[2] + 1;
  return ret;
}

// Grader entry point: split the n vertices of the graph with edges
// (p[i], q[i]) into parts of sizes a, b and c.
// Returns n labels. Label 1/2/3 puts a vertex in the part of size a/b/c.
// An all-zero vector means no split was found.
//
// Approach:
//   1. Keep in `adj` only edges that join two different dsu components,
//      i.e. a spanning forest (a spanning tree if the graph is connected).
//   2. Pick the vertex of that tree minimising the largest piece left after
//      removing it (the centroid), and re-root the tree there. Each child
//      subtree of the centroid becomes one dsu2 group.
//   3. If a group already has >= min(a, b, c) vertices, build the answer.
//   4. Otherwise glue groups along edges that avoid the centroid, and build
//      the answer as soon as a glued group reaches min(a, b, c).
vector<int> find_split(int n, int a, int b, int c, vector<int> p,
                       vector<int> q) {
  nbSommets = n;
  nbAretes = p.size();
  for (int i = 0; i < nbAretes; ++i) {
    aretes.emplace_back(p[i], q[i]);
    if (dsu.merge(p[i], q[i])) {
      adj[p[i]].push_back(q[i]);
      adj[q[i]].push_back(p[i]);
    }
  }
  int orderSz[3] = {a, b, c};
  sort(orderSz, orderSz + 3);
  const int minSz = orderSz[0];

  // Centroid search on the tree rooted at 0, then re-root at the centroid.
  dfs(0, 0);
  for (int i = 1; i < nbSommets; ++i)
    if (max(nbSommets - sz[i], maxSz[i]) <
        max(nbSommets - sz[centroid], maxSz[centroid]))
      centroid = i;
  dfs(centroid, centroid);

  // One dsu2 group per child subtree of the centroid.
  for (int i = 0; i < nbSommets; ++i)
    if (par[i] != centroid) {
      // The merge must not be written inside assert(): with NDEBUG the
      // argument is never evaluated, and these merges are required.
      [[maybe_unused]] const bool merged = dsu2.merge(i, par[i]);
      assert(merged);
    }
  for (int i = 0; i < nbSommets; ++i)
    if (i != centroid and dsu2.sz(i) >= minSz)
      return solve(a, b, c);

  // No subtree is large enough: glue subtrees via edges avoiding the centroid.
  for (auto [u, v] : aretes)
    if (u != centroid and v != centroid and dsu2.merge(u, v) and
        dsu2.sz(u) >= minSz)
      return solve(a, b, c);

  return vector<int>(nbSommets);
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...