Submission #1221114

#TimeUsernameProblemLanguageResultExecution timeMemory
1221114shmaxSplit the Attractions (IOI19_split)C++17
18 / 100
2094 ms26440 KiB
#include "split.h"

#include <bits/stdc++.h>

using namespace std;
using i32 = int;
#define int long long
#define len(x) (int)(x.size())
#define inf 1000'000'000'000'000'000LL
#define all(x) x.begin(), x.end()
#define low_bit(x) (x & (-x))

template<typename T>
using vec = vector<T>;

vec<i32> solve_for_tree(int n, int a, int b, int c, vec<vec<int>> g) {
    vec<i32> ans(n, 3);
    vec<int> sizes(n);

    vec<pair<i32, i32>> z = {{a, 1},
                             {b, 2},
                             {c, 3}};
    sort(all(z));
    a = z[0].first;
    int ax = z[0].second;
    b = z[1].first;
    int bx = z[1].second;
    int cx = z[2].second;
    std::fill(ans.begin(), ans.end(), cx);
    int splitter = -1;
    int ps;
    vec<int> SZ = {-1, -1};
    vec<int> Tid = {-1, -1};
    vec<bool> used1(n, false);
    function<void(int, int)> calc_sizes = [&](int v, int p) {
        sizes[v] = 1;
        used1[v] = true;
        for (auto u: g[v]) {
            if (u == p) continue;
            if (used1[u]) continue;
            calc_sizes(u, v);
            sizes[v] += sizes[u];
        }
        if (sizes[v] >= a and (n - sizes[v]) >= b) {
            splitter = v;
            SZ[0] = a;
            SZ[1] = b;
            ps = p;
            Tid[0] = ax;
            Tid[1] = bx;
        }
        if (sizes[v] >= b and (n - sizes[v]) >= a) {
            splitter = v;
            SZ[0] = b;
            SZ[1] = a;
            ps = p;
            Tid[0] = bx;
            Tid[1] = ax;
        }
    };
    calc_sizes(0, -1);
    if (splitter == -1) {
        return vec<i32>(0);
    }
    int need = 0;
    int T = 0;
    vec<bool> used(n, false);
    function<void(int, int)> choose = [&](int v, int p) {
        if (need == 0)
            return;
        if (used[v]) return;
        used[v] = true;
        need--;
        ans[v] = T;
        for (auto u: g[v]) {
            if (u == p) continue;
            choose(u, v);
        }
    };
    need = SZ[0];
    T = Tid[0];
    choose(splitter, ps);
    need = SZ[1];
    T = Tid[1];
    choose(0, -1);
    return ans;
}


struct DSU {
public:
    DSU() : _n(0) {}

    explicit DSU(int n) : _n(n), parent_or_size(n, -1) {}

    int unite(int a, int b) {
        assert(0 <= a && a < _n);
        assert(0 <= b && b < _n);
        int x = leader(a), y = leader(b);
        if (x == y) return x;
        if (-parent_or_size[x] < -parent_or_size[y]) std::swap(x, y);
        parent_or_size[x] += parent_or_size[y];
        parent_or_size[y] = x;
        return x;
    }

    bool one(int a, int b) {
        assert(0 <= a && a < _n);
        assert(0 <= b && b < _n);
        return leader(a) == leader(b);
    }

    int leader(int a) {
        assert(0 <= a && a < _n);
        if (parent_or_size[a] < 0) return a;
        return parent_or_size[a] = leader(parent_or_size[a]);
    }

    int size(int a) {
        assert(0 <= a && a < _n);
        return -parent_or_size[leader(a)];
    }

    std::vector<std::vector<int>> groups() {
        std::vector<int> leader_buf(_n), group_size(_n);
        for (int i = 0; i < _n; i++) {
            leader_buf[i] = leader(i);
            group_size[leader_buf[i]]++;
        }
        std::vector<std::vector<int>> result(_n);
        for (int i = 0; i < _n; i++) {
            result[i].reserve(group_size[i]);
        }
        for (int i = 0; i < _n; i++) {
            result[leader_buf[i]].push_back(i);
        }
        result.erase(
                std::remove_if(result.begin(), result.end(),
                               [&](const std::vector<int> &v) { return v.empty(); }),
                result.end());
        return result;
    }

private:
    int _n;
    // root node: -1 * component size
    // otherwise: parent
    std::vector<int> parent_or_size;
};


vector<i32> find_split(i32 n, i32 a, i32 b, i32 c, vector<i32> p, vector<i32> q) {

    for (int i = 0; i < n; i++) {
        DSU dsu(n);
        vec<vec<int>> g(n);
        vec<int> ord(len(p));
        iota(all(ord), 0);
        shuffle(all(ord), std::mt19937(std::random_device()()));
        for (int j: ord) {
            if (p[j] == i or q[j] == i) continue;
            if (dsu.one(p[j], q[j])) continue;
            dsu.unite(p[j], q[j]);
            g[p[j]].push_back(q[j]);
            g[q[j]].push_back(p[j]);
        }
        shuffle(all(ord), std::mt19937(std::random_device()()));

        for (int j: ord) {
            if (dsu.one(p[j], q[j])) continue;
            dsu.unite(p[j], q[j]);
            g[p[j]].push_back(q[j]);
            g[q[j]].push_back(p[j]);
        }
        vec<i32> res = solve_for_tree(n, a, b, c, g);
        if (!res.empty())
            return res;
    }
    return vec<i32>(n, 0);

//    vec<vec<int>> g(n);
//    for (int i = 0; i < len(p); i++) {
//        g[p[i]].push_back(q[i]);
//        g[q[i]].push_back(p[i]);
//    }
//
//    return solve_for_tree(n, a, b, c, g);
}



#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...