제출 #1236132

# | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리
1236132 | ProtonDecay314 | Split the Attractions (IOI19_split) | C++20
100 / 100
83 ms | 25268 KiB
#include "split.h"
#include<bits/stdc++.h>
using namespace std;
#define double long double
typedef long long ll;
typedef vector<ll> vll;
typedef vector<vll> vvll;
typedef vector<int> vi;
typedef vector<vi> vvi;
typedef vector<vvi> v3i;
typedef vector<v3i> v4i;
typedef vector<bool> vb;
typedef vector<vb> vvb;
typedef vector<double> vd;
typedef pair<int, int> pi;
typedef pair<ll, ll> pll;
typedef vector<pi> vpi;
#define INF(dt) numeric_limits<dt>::max()
#define NINF(dt) numeric_limits<dt>::min()
#define pb push_back

using namespace std;

struct UF {
    vi sz, pr, tot_csize;
    int n;

    UF(int a_n, const vi& a_tot_csize): sz(a_n, 1), pr(a_n, 0), n(a_n) {
        for(int i = 0; i < n; i++) pr[i] = i;
        tot_csize = a_tot_csize;
    }

    inline int find(int i) {
        return (i == pr[i] ? i : pr[i] = find(pr[i]));
    }

    inline bool conn(int i, int j) {
        return find(i) == find(j);
    }

    inline void unify(int i, int j) {
        int pri = find(i), prj = find(j);
        if(pri == prj) return;

        if(sz[pri] < sz[prj]) {
            pr[pri] = prj;
            sz[prj] += sz[pri];
            tot_csize[prj] += tot_csize[pri];
        } else {
            pr[prj] = pri;
            sz[pri] += sz[prj];
            tot_csize[pri] += tot_csize[prj];
        }
    }
};

void dfs1(int i, int pr, vi& sz, vvi& dfs_edges) {
    sz[i] = 1;
    for(int j : dfs_edges[i]) {
        if(j == pr) continue;
        dfs1(j, i, sz, dfs_edges);
        sz[i] += sz[j];
    }

    int num_e = dfs_edges[i].size();

    for(int ei = 1; ei < num_e; ei++) {
        if(dfs_edges[i][ei] == pr) continue;
        if(sz[dfs_edges[i][ei]] > sz[dfs_edges[i][0]] || dfs_edges[i][0] == pr) swap(dfs_edges[i][0], dfs_edges[i][ei]);
    }
}

void dfs3(int i, int pr, vb& a_ss, const vvi& dfs_edges) {
    a_ss[i] = true;
    for(int j : dfs_edges[i]) {
        if(j == pr) continue;
        dfs3(j, i, a_ss, dfs_edges);
    }
};

/*
 * Core of find_split(): labels vertices 1 (set A, a vertices), 2 (set B, b vertices)
 * and 3 (every other vertex), where A and B are each grown by a BFS and so are
 * connected. Returns all zeros when this construction finds no split.
 * find_split() passes the two smallest requested sizes, so a <= b <= c here.
 *
 * Steps:
 *   1. Build a spanning tree (dfs_edges) with an iterative DFS from vertex 0.
 *      NOTE(review): only vertex 0's component is spanned; assumes adj is connected.
 *   2. Find a centroid of that tree and re-root the tree there.
 *   3. If some centroid subtree has >= a vertices, take A inside it and B around
 *      the centroid, both by BFS over tree edges.
 *   4. Otherwise group centroid subtrees that are joined by graph edges
 *      (union-find). If a group has >= a vertices, flood enough of its subtrees
 *      into a_ss, take A by BFS inside a_ss, and take B by BFS from the centroid
 *      over the still-unassigned vertices.
 *
 * n: vertex count; m: edge count (not used in the body); adj: adjacency lists.
 */
vi solve(int n, int m, int a, int b, const vvi& adj) {
    // 0 = unassigned, 1 = A, 2 = B, 3 = C; left all-zero when no split is found
    vi res(n, 0);

    // find dfs tree
    // Stack entries are (vertex, vertex it was reached from). The first time a
    // vertex is popped, the edge to the vertex it was reached from becomes a tree edge.
    stack<pi> s;
    s.push({0, 0});

    vb vis(n, false);

    // adjacency lists of the spanning tree only
    vvi dfs_edges(n, vi());

    while(!s.empty()) {
        pi cpair = s.top();

        int i = cpair.first, pr = cpair.second;

        s.pop();

        if(vis[i]) continue;
        vis[i] = true;
        if(i != pr) {
            dfs_edges[pr].pb(i);
            dfs_edges[i].pb(pr);
            // cerr << "e: " << i << ", " << pr << endl;
        }

        for(int j : adj[i]) {
            if(j == pr) continue;
            s.push({j, i});
        }
    }

    // find centroid
    vi sz(n, 0);

    
    dfs1(0, 0, sz, dfs_edges);

    // Walk down from vertex 0: while some child subtree has more than n/2 vertices,
    // step into it. dfs1 put the heaviest child at index 0, and an oversized child
    // is necessarily the heaviest one.
    int centroid = 0;
    int prev = 0;

    while(true) {
        bool all_small = true;
        for(int i : dfs_edges[centroid]) {
            if(i == prev) continue;
            if(sz[i] > (n >> 1)) all_small = false;
        }
        if(all_small) break;
        prev = centroid;
        centroid = dfs_edges[centroid][0];
    }

    /*
    first check if the dfs tree works on its own
    */
    // Re-root at the centroid: every neighbour of the centroid becomes a child whose
    // subtree has at most n/2 vertices (checked by the assert below).
    dfs1(centroid, centroid, sz, dfs_edges);

    for(int sc : dfs_edges[centroid]) {
        assert(sz[sc] <= (n >> 1));
        if(sz[sc] >= a) {
            /*
            bfs from starting child (sc) to generate set a
            */
            // Tree BFS from sc that never enters the centroid, so it stays inside
            // sc's subtree (which has >= a vertices) and stops after marking a.
            queue<int> q;
            q.push(sc);

            int num_marked = 0;

            vb vis(n, false);

            while(!q.empty()) {
                int i = q.front();
                q.pop();

                if(i == centroid) continue;
                if(vis[i]) continue;
                vis[i] = true;
                res[i] = 1;
                num_marked++;
                if(num_marked == a) break;

                for(int j : dfs_edges[i]) {
                    q.push(j);
                }
            }

            // discard entries left over from the early break
            while(!q.empty()) {
                q.pop();
            }

            // Tree BFS from the centroid that never enters sc, hence never reaches
            // sc's subtree. At least n - n/2 vertices are reachable, and b <= c
            // gives b <= n/2, so b vertices get marked.
            q.push(centroid);
            for(int i = 0; i < n; i++) vis[i] = false;

            num_marked = 0;
            while(!q.empty()) {
                int i = q.front();
                q.pop();

                if(i == sc) continue;
                if(vis[i]) continue;
                vis[i] = true;
                res[i] = 2;
                num_marked++;
                if(num_marked == b) break;

                for(int j : dfs_edges[i]) {
                    q.push(j);
                }
            }

            // everything not in A or B becomes C
            for(int i = 0; i < n; i++) {
                if(res[i] == 0) res[i] = 3;
            }

            return res;
        }
    }

    /*
    mark each node by the subtree it falls under using floodfill
    */
    // tr_id[v] = index k such that v lies in the subtree of dfs_edges[centroid][k];
    // tr_root[v] = that child itself (written here but not read later). The centroid
    // keeps -1. The outer `vis` from the DFS is reused.
    int sid = 0; // sid = subtree id

    vi tr_id(n, -1);
    vi tr_root(n, -1);

    for(int i = 0; i < n; i++) vis[i] = false;
    queue<int> q;

    for(int s : dfs_edges[centroid]) {
        q.push(s);
        while(!q.empty()) {
            int i = q.front();
            q.pop();

            if(i == centroid) continue;
            if(vis[i]) continue;
            vis[i] = true;
            tr_id[i] = sid;
            tr_root[i] = s;

            for(int j : dfs_edges[i]) {
                q.push(j);
            }
        }

        sid++;
    }

    int n_child = sid;

    /*
    a_ss ("A superset") marks the vertices of the centroid subtrees chosen to host
    set A; dfs3 floods whole subtrees into it.
    */
    vb a_ss(n, false);

    /*
    simulate centroid removal
    */
    // tot_csize[k] = number of vertices in the k-th centroid subtree
    vi tot_csize(n_child, 0);

    for(int i = 0; i < n_child; i++) {
        tot_csize[i] = sz[dfs_edges[centroid][i]];
    }

    UF uf(n_child, tot_csize);

    // For every graph edge (avoiding the centroid) that joins two different centroid
    // subtrees, merge their groups. Only edges that actually merge two groups are
    // recorded, so tree_comp_adj is a spanning forest over the subtrees.
    vvi tree_comp_adj(n_child, vi());

    for(int i = 0; i < n; i++) {
        if(i == centroid) continue;
        for(int j : adj[i]) {
            if(j == centroid) continue;
            int ci = tr_id[i], cj = tr_id[j];
            if(!uf.conn(ci, cj)) {
                uf.unify(ci, cj);
                tree_comp_adj[ci].pb(cj);
                tree_comp_adj[cj].pb(ci);
            }
        }
    }

    /*
    try to generate set A using one of the components found by union find
    */

    // Only the first subtree whose group has >= a vertices does any work: the loop
    // returns at the end of that iteration.
    for(int cid = 0; cid < n_child; cid++) {
        if(uf.tot_csize[uf.find(cid)] < a) continue;

        // generate a partition

        // set A
        // BFS over the subtree forest from cid, flooding whole subtrees into a_ss
        // until at least a vertices are covered. Every subtree has fewer than a
        // vertices here (otherwise the earlier phase would have returned), so
        // a_ss ends up with fewer than 2a vertices. Each added subtree is joined to
        // an earlier one by a graph edge, so a_ss is connected.
        int cur_num_elem = 0;

        vb vis(n_child, false);
        queue<int> q;
        q.push(cid);

        while(!q.empty()) {
            int i = q.front();
            q.pop();

            if(vis[i]) continue;
            vis[i] = true;
            dfs3(dfs_edges[centroid][i], centroid, a_ss, dfs_edges);
            cur_num_elem += sz[dfs_edges[centroid][i]];
            if(cur_num_elem >= a) break;

            for(int j : tree_comp_adj[i]) {
                q.push(j);
            }
        }

        /// BFS in the original graph, restricted to a_ss, from the root of subtree
        /// cid: the first a vertices reached become set A.
        while(!q.empty()) q.pop();

        vb vis2(n, false);
        q.push(dfs_edges[centroid][cid]);
        int num_marked = 0;

        while(!q.empty()) {
            int i = q.front();
            q.pop();

            if(!a_ss[i]) continue;
            if(i == centroid) continue;
            if(vis2[i]) continue;
            vis2[i] = true;
            res[i] = 1;
            num_marked++;
            if(num_marked == a) break;

            for(int j : adj[i]) {
                q.push(j);
            }
        }

        while(!q.empty()) q.pop();

        // Set B: BFS in the original graph from the centroid through unassigned
        // vertices. The centroid plus the subtrees outside a_ss are connected via
        // tree edges and number more than n - 2a vertices, which is >= b when
        // a <= b <= c and a + b + c = n (presumably guaranteed by the task).
        num_marked = 0;

        for(int i = 0; i < n; i++) vis2[i] = false;
        q.push(centroid);

        while(!q.empty()) {
            int i = q.front();
            q.pop();

            if(res[i] != 0) continue;
            if(vis2[i]) continue;
            vis2[i] = true;
            res[i] = 2;
            num_marked++;
            if(num_marked == b) break;

            for(int j : adj[i]) {
                q.push(j);
            }
        }

        // everything not in A or B becomes C
        for(int i = 0; i < n; i++) {
            if(res[i] == 0) res[i] = 3;
        }

        return res;
    }

    // no group can host A: report failure with the all-zero labelling
    return res;
}

/*
 * Grader entry point. Builds adjacency lists from the edge list (p[e], q[e]),
 * asks solve() for a split using the two smallest requested sizes, and maps
 * solve()'s labels back to the caller's labels. solve() uses 1 for the smallest
 * size, 2 for the middle and 3 for the largest; the caller expects 1 for a, 2 for b
 * and 3 for c. Zeros (no split found) are returned unchanged.
 */
vector<int> find_split(int n, int a, int b, int c, vector<int> p, vector<int> q) {
    // (requested size, output label), ascending by size; equal sizes keep the smaller label first
    vpi size_label = {{a, 1}, {b, 2}, {c, 3}};
    sort(size_label.begin(), size_label.end());

    const int edge_count = static_cast<int>(p.size());
    vvi graph(n);
    for (int e = 0; e < edge_count; ++e) {
        const int u = p[e];
        const int v = q[e];
        graph[u].push_back(v);
        graph[v].push_back(u);
    }

    vi labels = solve(n, edge_count, size_label[0].first, size_label[1].first, graph);

    for (int& label : labels) {
        if (label != 0) label = size_label[label - 1].second;
    }
    return labels;
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...