Submission #1193406

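| # | Time | Username | Problem | Language | Result | Execution time | Memory |
|---|------|----------|---------|----------|--------|----------------|--------|
| 1193406 | | dong_gas | Split the Attractions (IOI19_split) | C++20 | 0 / 100 | 51 ms | 15432 KiB |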
#include <bits/extc++.h>
#define all(v) v.begin(), v.end()
using namespace std;

// Capacity of the per-vertex global arrays (100000 entries plus a little slack).
constexpr int MAXN = 100'000 + 10;

/// Union-find over indices 0..n (n + 1 slots), combining union by height
/// (rank) with full path compression.
struct disjoint_set {
    std::vector<int> papa;    // parent link; -1 marks a root
    std::vector<int> height;  // rank bound, meaningful only at roots
    std::vector<int> sz;      // component size, meaningful only at roots

    disjoint_set(int n) : papa(n + 1, -1), height(n + 1, 0), sz(n + 1, 1) {}

    /// Returns the root of u's set and re-points every vertex on the walked
    /// path directly at that root.
    int Find(int u) {
        int root = u;
        while (papa[root] != -1) root = papa[root];
        while (u != root) {
            const int next = papa[u];
            papa[u] = root;
            u = next;
        }
        return root;
    }

    /// Merges the sets of u and v. Returns false if they were already joined.
    bool Union(int u, int v) {
        int ru = Find(u);
        int rv = Find(v);
        if (ru == rv) return false;
        if (height[ru] < height[rv]) std::swap(ru, rv);  // ru becomes the taller root
        papa[rv] = ru;
        sz[ru] += sz[rv];
        if (height[ru] == height[rv]) ++height[ru];
        return true;
    }
} ds(1);


// Shared state used by find_split and its helpers. Nothing in this file
// resets it, so a single find_split call per process is assumed.
int n, m;          // vertex and edge counts
int a, b;          // remaining quotas for internal labels 1 (smallest) and 2 (middle)
int c;             // last subtree id handed out to grouping
int cent;          // centroid of the spanning tree
int A = -1;        // start vertex for label 1, -1 while none is found
int B = -1;        // start vertex for label 2 (set to the centroid)
int visited[MAXN]; // 1 once a vertex is claimed or guarded
int papa[MAXN];    // spanning-tree parent written by dfs
int sz[MAXN];      // spanning-tree subtree sizes accumulated by dfs
int col[MAXN];     // subtree id written by grouping
std::vector<int> res;         // internal label per vertex, remapped at the end
std::vector<int> adj[MAXN];   // full graph adjacency
std::vector<int> tree[MAXN];  // spanning-tree adjacency

// Roots the spanning tree at u (arriving from p). Fills sz[] with subtree
// sizes and papa[] with tree parents, and returns sz[u].
// Recursion depth equals the depth of the spanning tree.
int dfs(int u, int p = -1) {
    // Count u itself. Without this every size stays 0, and get_cent would
    // always stop at the root instead of finding the centroid.
    sz[u] = 1;
    for (int v : tree[u]) {
        if (v == p) continue;
        papa[v] = u;
        sz[u] += dfs(v, u);
    }
    return sz[u];
}

// Starting at u (reached from p), repeatedly steps into the first child whose
// subtree size sz[] exceeds cnt / 2. Returns the vertex where no such child
// exists. sz[] must have been filled by dfs from the same root.
int get_cent(int u, int p, int cnt) {
    const int half = cnt / 2;
    bool descended = true;
    while (descended) {
        descended = false;
        for (const int v : tree[u]) {
            if (v != p && sz[v] > half) {
                p = u;
                u = v;
                descended = true;
                break;
            }
        }
    }
    return u;
}


// Walks the spanning tree from u, never stepping back to the vertex it came
// from. Every reached vertex is tagged with id nc in col[] and joined to its
// tree parent's DSU set.
void grouping(int u, int p, int nc) {
    col[u] = nc;
    for (const int child : tree[u]) {
        if (child != p) {
            ds.Union(u, child);
            grouping(child, u, nc);
        }
    }
}

// Claims up to `a` vertices for internal label 1 by breadth-first search from u.
// It follows only graph edges whose endpoints share a DSU set, i.e. it stays in
// the component chosen by find_split. Each claimed vertex is marked visited and
// gets res = 1. Vertices are claimed when discovered from an already-claimed
// neighbour, so the claimed set is connected.
// The search stops the moment the quota reaches 0. Unclaimed vertices are left
// unvisited so gogo can still use them. The search is iterative, so long paths
// cannot overflow the call stack.
void go(int u) {
    if (a <= 0) return;
    std::queue<int> frontier;
    visited[u] = 1;
    res[u] = 1;
    --a;
    frontier.push(u);
    while (a > 0 && !frontier.empty()) {
        const int x = frontier.front();
        frontier.pop();
        for (int v : adj[x]) {
            if (a <= 0) break;
            if (visited[v] || ds.Find(x) != ds.Find(v)) continue;
            visited[v] = 1;
            res[v] = 1;
            --a;
            frontier.push(v);
        }
    }
}

// Claims up to `b` vertices for internal label 2 by breadth-first search from
// u over vertices not yet visited. Each claimed vertex is marked visited and
// gets res = 2. Vertices are claimed when discovered from an already-claimed
// neighbour, so the claimed set is connected.
// The search stops as soon as the quota reaches 0. It is iterative rather than
// recursive to avoid deep call stacks.
void gogo(int u) {
    if (b <= 0) return;
    std::queue<int> frontier;
    visited[u] = 1;
    res[u] = 2;
    --b;
    frontier.push(u);
    while (b > 0 && !frontier.empty()) {
        const int x = frontier.front();
        frontier.pop();
        for (int v : adj[x]) {
            if (b <= 0) break;
            if (visited[v]) continue;
            visited[v] = 1;
            res[v] = 2;
            --b;
            frontier.push(v);
        }
    }
}

// Splits vertices 0.._n-1 of the graph with edges (p[i], q[i]) into groups of
// sizes _a, _b, _c, labelled 1, 2, 3. Returns one label per vertex, or all
// zeros when no split is found.
//
// Strategy:
//  - Build a spanning tree and find its centroid.
//  - Treat each tree subtree of the centroid as a DSU set. If none is large
//    enough, merge sets along non-tree edges that avoid the centroid until one
//    reaches the smallest requested size.
//  - Grow the smallest group inside that set (go) and the middle group from
//    the centroid (gogo). All other vertices keep the largest group's label.
vector<int> find_split(int _n, int _a, int _b, int _c, vector<int> p, vector<int> q) {
    // Sort requested sizes ascending, remembering each size's original label,
    // so internal labels 1/2/3 (smallest/middle/largest) can be mapped back.
    vector<pair<int, int>> t = {{_a, 1}, {_b, 2}, {_c, 3}};
    sort(t.begin(), t.end());
    a = t[0].first, b = t[1].first, n = _n, m = static_cast<int>(p.size());

    // Edges that join two DSU sets form the spanning tree. Edges that close a
    // cycle are kept for the merge phase.
    ds = disjoint_set(n);
    vector<pair<int, int>> edges;
    for (int i = 0; i < m; i++) {
        if (ds.Union(p[i], q[i])) {
            tree[p[i]].push_back(q[i]);
            tree[q[i]].push_back(p[i]);  // the reverse tree edge must point back at p[i]
        } else {
            edges.emplace_back(p[i], q[i]);
        }
        adj[p[i]].push_back(q[i]);
        adj[q[i]].push_back(p[i]);
    }

    dfs(0);
    B = cent = get_cent(0, -1, n);
    visited[cent] = 1;  // keep go() out of the centroid

    // One DSU set per tree subtree hanging off the centroid. Only tree
    // neighbours are valid starting points: from a non-tree neighbour,
    // grouping() would walk through the centroid into every other subtree.
    ds = disjoint_set(n);
    for (int v : tree[cent]) grouping(v, cent, ++c);
    for (int v : tree[cent]) {
        if (ds.sz[ds.Find(v)] >= a) A = v;
    }

    // No single subtree is large enough: merge sets joined by non-tree edges
    // that avoid the centroid, stopping once one set reaches size a.
    for (const auto& [u, v] : edges) {
        if (A != -1) break;
        if (u == cent || v == cent) continue;
        ds.Union(u, v);
        if (ds.sz[ds.Find(u)] >= a) A = ds.Find(u);
    }

    if (A == -1) return vector<int>(n, 0);
    res.assign(n, 3);
    go(A);              // explore only vertices in the same DSU set as A
    visited[cent] = 0;  // the centroid becomes available to the middle group
    gogo(B);
    for (int i = 0; i < n; i++) res[i] = t[res[i] - 1].second;
    return res;
}
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...