Submission #1288242

#TimeUsernameProblemLanguageResultExecution timeMemory
1288242al95ireyizSplit the Attractions (IOI19_split)C++20
40 / 100
87 ms22964 KiB
#include <bits/extc++.h>
#define all(v) v.begin(), v.end()
using namespace std;

const int MAXN = 1e5 + 10;

// Union-find over integer ids with union by height and path compression.
// papa[x] == -1 marks x as the representative of its set; sz[] and height[]
// are only meaningful at representatives.
struct disjoint_set {
    int papa[MAXN], height[MAXN], sz[MAXN];

    // Turns ids 0..n-1 into singleton sets.
    void init(int n) {
        for (int id = 0; id < n; ++id) {
            papa[id] = -1;
            height[id] = 0;
            sz[id] = 1;
        }
    }

    // Returns the representative of u's set and re-points every node on the
    // walked path directly at it.
    int Find(int u) {
        int root = u;
        while (papa[root] != -1) root = papa[root];
        while (u != root) {
            const int above = papa[u];
            papa[u] = root;
            u = above;
        }
        return root;
    }

    // Merges the sets containing u and v. Returns false when they were
    // already one set, true when a merge happened.
    bool Union(int u, int v) {
        int keep = Find(u);
        int drop = Find(v);
        if (keep == drop) return false;
        if (height[keep] < height[drop]) {
            // Hang the shallower representative under the deeper one.
            const int tmp = keep;
            keep = drop;
            drop = tmp;
        }
        const bool same_height = (height[keep] == height[drop]);
        papa[drop] = keep;
        sz[keep] += sz[drop];
        if (same_height) ++height[keep];
        return true;
    }
} ds;


// File-level working state shared by find_split and its helpers. Everything
// here starts zero-initialised / empty except A and B, which start at -1.
int n, m;                // set by find_split: _n and the number of edges
int a, b, c;             // set by find_split: the three sizes in ascending order
int cent;                // vertex returned by get_cent
int A = -1, B = -1;      // start vertices handed to go() and gogo()
int visited[MAXN];       // non-zero once go()/gogo() has entered the vertex
int papa[MAXN];          // parent in `tree`, written by dfs()
int sz[MAXN];            // subtree size in `tree`, written by dfs()
vector<int> res;         // per-vertex label, filled by find_split/go/gogo
vector<int> adj[MAXN];   // adjacency lists holding every input edge
vector<int> tree[MAXN];  // adjacency lists holding only spanning-tree edges

// Walks `tree` from u (arriving from p), storing each vertex's subtree size in
// sz[] and each child's parent in papa[]. Returns sz[u].
int dfs(int u, int p = -1) {
    sz[u] = 1;
    for (const int child : tree[u]) {
        if (child == p) continue;
        const int below = dfs(child, u);
        sz[u] += below;
        papa[child] = u;
    }
    return sz[u];
}

// Starting at u (reached from p), repeatedly steps into the first neighbour
// other than the previous vertex whose sz[] exceeds cnt / 2; returns the vertex
// at which no such neighbour exists. Relies on sz[] as filled in by dfs().
int get_cent(int u, int p, int cnt) {
    const int half = cnt / 2;
    while (true) {
        int heavy = -1;
        for (const int nb : tree[u]) {
            if (nb != p && sz[nb] > half) {
                heavy = nb;
                break;
            }
        }
        if (heavy == -1) return u;
        p = u;
        u = heavy;
    }
}


// Merges, in `ds`, u with every vertex reachable from it through `tree`
// without going back through p.
void grouping(int u, int p) {
    for (const int child : tree[u]) {
        if (child != p) {
            ds.Union(u, child);
            grouping(child, u);
        }
    }
}

// Depth-first walk over `adj` from u that labels vertices with 1 while the
// global counter `a` is positive. Only unvisited vertices in the same `ds` set
// as u, and different from `cent`, are entered. `a` is decremented once per
// entered vertex.
void go(int u) {
    visited[u] = 1;
    const bool quota_left = (a > 0);
    --a;
    if (quota_left) res[u] = 1;
    if (a == 0) return;  // quota just used up: nothing more to label
    for (const int nxt : adj[u]) {
        if (visited[nxt]) continue;
        if (ds.Find(u) != ds.Find(nxt)) continue;
        if (nxt == cent) continue;
        if (a > 0) go(nxt);
    }
}

// Depth-first walk over `adj` from u that labels vertices with 2 while the
// global counter `b` is positive. Any vertex not yet marked in visited[] may be
// entered. `b` is decremented once per entered vertex.
void gogo(int u) {
    visited[u] = 1;
    const bool quota_left = (b > 0);
    --b;
    if (quota_left) res[u] = 2;
    if (b == 0) return;  // quota just used up: nothing more to label
    for (const int nxt : adj[u]) {
        if (visited[nxt]) continue;
        if (b > 0) gogo(nxt);
    }
}

// Splits the vertices of a graph into three labelled parts of sizes _a, _b, _c
// such that the two smallest parts are each connected.
//
// Parameters:
//   _n          number of vertices (ids 0.._n-1)
//   _a, _b, _c  requested sizes of parts 1, 2 and 3
//   p, q        edge i joins p[i] and q[i]
// Returns a vector of length _n holding each vertex's part (1, 2 or 3), or all
// zeros when no split was found.
//
// Method: build a spanning tree, take its centroid `cent`, and group the other
// vertices by the centroid subtree they belong to. A group of size >= a (the
// smallest requested size) supplies the smallest part; the second smallest part
// is then grown from `cent`. If no single subtree is large enough, groups are
// merged through non-tree edges that avoid `cent` until one reaches size a.
// NOTE(review): assumes the input graph is connected and _n < MAXN, as the
// original code did — confirm against the task constraints.
vector<int> find_split(int _n, int _a, int _b, int _c, vector<int> p, vector<int> q) {
    // t[k] = (size, original label), ascending by size, so internal label k+1
    // maps back to t[k].second at the end.
    vector<pair<int, int>> t = {{_a, 1}, {_b, 2}, {_c, 3}};
    sort(t.begin(), t.end());
    a = t[0].first, b = t[1].first, c = t[2].first, n = _n, m = p.size();

    // Reset the file-level state this function relies on, so a repeated call
    // does not see leftovers from an earlier one.
    A = B = -1;
    for (int i = 0; i < n; i++) {
        adj[i].clear();
        tree[i].clear();
        visited[i] = 0;
    }

    // First DSU pass: pick a spanning tree; remember the edges left out of it.
    ds.init(n);
    vector<pair<int, int>> edges;
    for (int i = 0; i < m; i++) {
        if (ds.Union(p[i], q[i])) tree[p[i]].push_back(q[i]), tree[q[i]].push_back(p[i]);
        else edges.push_back({p[i], q[i]});
        adj[p[i]].push_back(q[i]), adj[q[i]].push_back(p[i]);
    }

    dfs(0), B = cent = get_cent(0, -1, n);
    visited[cent] = 1;  // keep go() out of the centroid

    // Second DSU pass: one set per subtree hanging off the centroid.
    ds.init(n);
    for (int v: tree[cent]) grouping(v, cent);
    for (int v: tree[cent]) {
        if (ds.sz[ds.Find(v)] >= a) A = v;
    }

    // No single subtree is big enough: glue subtrees together through non-tree
    // edges that do not touch the centroid. Both sides are smaller than `a`
    // before each merge, so the first set to reach size a stays below 2a and
    // leaves more than b vertices (centroid included) for the second part.
    for (const auto& [u, v]: edges) {
        if (A != -1) break;
        if (u == cent || v == cent) continue;
        ds.Union(u, v);
        const int root = ds.Find(u);
        if (ds.sz[root] >= a) A = root;
    }

    // Every component of the graph without the centroid is smaller than `a`.
    if (A == -1) return vector<int>(n, 0);

    res.assign(n, 3);
    go(A);  // only explores vertices in the same ds set as A
    visited[cent] = 0;
    gogo(B);
    for (int i = 0; i < n; i++) res[i] = t[res[i] - 1].second;
    return res;
}
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...