Submission #1233850

# | Submission time | ID | Problem | Language | Result | Execution time | Memory
1233850 | jasonic | Split the Attractions (IOI19_split) | C++20
100 / 100
75 ms | 20296 KiB
#include <bits/stdc++.h>
#include "split.h"
using namespace std;

// Shorthand for 64-bit integers (not used anywhere in this file).
#define ll long long
// Fast-I/O boilerplate (not invoked anywhere in this file).
#define fastIO cin.tie(0); ios::sync_with_stdio(false)

// Output labels for the smallest, middle and largest requested part, in that order.
// find_split() permutes these together with a/b/c when it sorts the part sizes.
int labels[3] = {1, 2, 3};
// Vertex count, edge count, and the three requested part sizes
// (find_split() reorders a, b, c so that a <= b <= c).
int n, m, a, b, c;

// Adjacency lists of the full input graph.
vector<vector<int>> bigadjlist;
// Adjacency lists of a spanning tree of the input graph (built by find_split()).
vector<vector<int>> adjlist;
// Subtree sizes for the current rooting of adjlist, filled by calc_sizes().
// Fixed capacity: assumes n <= 200005.
int sz[200005];

// Computes sz[u] = number of vertices in u's subtree for every vertex u reachable from v
// in the tree `adjlist`, treating p as v's parent (p is never entered, sz[p] is untouched).
// adjlist must be acyclic (a tree); otherwise the walk would never terminate.
//
// Implemented iteratively: vertices are listed parents-before-children, then sizes are
// accumulated in reverse order. This gives the same sz[] values as a recursive DFS, but a
// long path cannot overflow the call stack.
void calc_sizes(int v, int p = -1) {
    // (vertex, parent) pairs in BFS order; every parent precedes all of its children.
    vector<pair<int, int>> order;
    order.emplace_back(v, p);
    for (size_t idx = 0; idx < order.size(); ++idx) {
        // Copy, not reference: emplace_back below may reallocate `order`.
        const auto [u, parent] = order[idx];
        sz[u] = 1;
        for (int w : adjlist[u]) {
            if (w != parent) order.emplace_back(w, u);
        }
    }

    // Walking backwards, each subtree is complete before it is folded into its parent.
    // Index 0 is v itself; its parent p must not be modified, so stop before it.
    for (size_t idx = order.size() - 1; idx > 0; --idx) {
        sz[order[idx].second] += sz[order[idx].first];
    }
}

// Walks down the tree `adjlist` from v (whose parent is p), using the subtree sizes in sz[].
// At each vertex it steps into the first child whose subtree holds more than half of the
// n vertices. It returns the first vertex that has no such child. When started from the
// root of the rooting used for sz[] (p == -1), every component left after removing the
// returned vertex has at most n/2 vertices.
int get_centroid(int v, int p = -1) {
    int here = v;
    int came_from = p;
    while (true) {
        int heavy_child = -1;
        for (int next : adjlist[here]) {
            if (next != came_from && sz[next] * 2 > n) {
                heavy_child = next;
                break;
            }
        }
        if (heavy_child < 0) return here;  // no oversized child: stop here
        came_from = here;
        here = heavy_child;
    }
}

// Number of vertices dfs_fill() still has to label. Callers set it before each call;
// dfs_fill() decrements it once per vertex it labels and stops when it reaches 0.
int need = 0;

// Labels up to `need` vertices with `val`, in DFS preorder starting at v.
//
// adj: adjacency lists to walk.
// arr: per-vertex labels. Apart from v itself (always labelled), only vertices whose entry
//      is -1 are entered, so callers can confine the walk to a region by pre-marking it.
// val: label to write.
// v:   start vertex.
// p:   vertex treated as v's parent; it is never stepped into directly from v.
//
// All labelled vertices come from a single DFS rooted at v, so they form a connected set
// in adj. An explicit frame stack replaces recursion, so the depth is not limited by the
// call stack; the visiting order and early termination match the recursive formulation.
void dfs_fill(const vector<vector<int>> &adj, vector<int> &arr, int val, int v, int p = -1) {
    if (need == 0) return;
    arr[v] = val;
    if (--need == 0) return;

    // One frame per vertex on the current DFS path: the vertex, the vertex it was entered
    // from, and the index of the next adjacency entry to try.
    struct Frame {
        int vertex;
        int from;
        size_t next;
    };
    vector<Frame> frames;
    frames.push_back({v, p, 0});

    while (!frames.empty()) {
        Frame &top = frames.back();
        if (top.next == adj[top.vertex].size()) {
            frames.pop_back();  // every neighbour tried: backtrack
            continue;
        }
        const int u = adj[top.vertex][top.next++];
        if (u == top.from || arr[u] != -1) continue;

        arr[u] = val;
        if (--need == 0) return;
        const int entered_from = top.vertex;  // read before push_back may reallocate
        frames.push_back({u, entered_from, 0});
    }
}

// Splits the N vertices of the graph with edges p[e]-q[e] into parts of sizes A, B and C.
//
// Returns a vector whose entry v is 1, 2 or 3: the part vertex v belongs to, labelled in
// the order A, B, C. Returns a vector of N zeros when no split is found. The two smallest
// parts are each grown by one DFS from a single vertex, so they are connected; every
// remaining vertex receives the label of the largest part.
//
// Strategy: take a spanning tree and root it at its centroid, then
//   1. if some child subtree of the centroid has >= a vertices, take the a-part inside it
//      and the b-part from the centroid's side along tree edges;
//   2. otherwise, merge the child subtrees through input edges that avoid the centroid
//      until one merged group reaches >= a vertices. Take the a-part inside that group
//      (using all input edges), then the b-part from the centroid along tree edges.
//   If neither step succeeds, no split is reported.
//
// NOTE(review): assumes the graph is connected (the spanning tree is grown only from
// vertex 0) and N == A + B + C (leftover vertices get the third label); neither is checked.
vector<int> find_split(int N, int A, int B, int C, vector<int> p, vector<int> q) {
    n = N;
    m = static_cast<int>(p.size());
    a = A, b = B, c = C;

    // Reset the global label permutation: it is permuted below, and a previous call would
    // otherwise leave it in a swapped state.
    labels[0] = 1;
    labels[1] = 2;
    labels[2] = 3;

    bigadjlist = vector<vector<int>>(n);
    adjlist = vector<vector<int>>(n);
    for (int e = 0; e < m; e++) {
        bigadjlist[p[e]].push_back(q[e]);
        bigadjlist[q[e]].push_back(p[e]);
    }

    // Sort the sizes so that a <= b <= c, permuting the output labels alongside them.
    if (c < a) {
        swap(a, c);
        swap(labels[0], labels[2]);
    }
    if (b < a) {
        swap(a, b);
        swap(labels[0], labels[1]);
    }
    if (c < b) {
        swap(c, b);
        swap(labels[1], labels[2]);
    }

    // Build a spanning tree from vertex 0. Vertices are marked when pushed, so this is a
    // spanning tree but not a strict DFS tree.
    vector<char> visited(n, 0);
    stack<int> pending;
    pending.push(0);
    visited[0] = 1;
    while (!pending.empty()) {
        const int curr = pending.top();
        pending.pop();
        for (int nb : bigadjlist[curr]) {
            if (visited[nb]) continue;
            visited[nb] = 1;
            adjlist[curr].push_back(nb);
            adjlist[nb].push_back(curr);
            pending.push(nb);
        }
    }

    // Root the tree at its centroid; afterwards sz[child] is the size of each child subtree.
    calc_sizes(0);
    const int centroid = get_centroid(0);
    calc_sizes(centroid);

    // Step 1: a single child subtree is large enough to host the a-part.
    for (int child : adjlist[centroid]) {
        if (sz[child] < a) continue;

        vector<int> ans(n, -1);
        need = a;
        dfs_fill(adjlist, ans, labels[0], child, centroid);

        need = b;
        dfs_fill(adjlist, ans, labels[1], centroid, child);

        for (int &lab : ans) {
            if (lab == -1) lab = labels[2];
        }
        return ans;
    }

    // Step 2: merge child subtrees connected by edges that avoid the centroid, and check
    // whether some merged group reaches a vertices.
    struct DSUTree {
        vector<int> par, comp_size;

        explicit DSUTree(int count) : par(count), comp_size(count, 1) {
            for (int v = 0; v < count; v++) par[v] = v;
        }

        int parent(int v) {
            return par[v] == v ? v : par[v] = parent(par[v]);
        }

        int size(int v) {
            return comp_size[parent(v)];
        }

        bool together(int u, int v) {
            return parent(u) == parent(v);
        }

        // Union by size: the larger root absorbs the smaller one.
        void merge(int u, int v) {
            u = parent(u);
            v = parent(v);
            if (u == v) return;
            if (comp_size[u] <= comp_size[v]) swap(u, v);
            comp_size[u] += comp_size[v];
            par[v] = u;
        }
    };

    // Start with one DSU component per child subtree of the centroid.
    DSUTree dsu(n);
    for (int v = 0; v < n; v++) {
        if (v == centroid) continue;
        for (int nb : adjlist[v]) {
            if (nb != centroid) dsu.merge(v, nb);
        }
    }

    for (int e = 0; e < m; e++) {
        const int x = p[e], y = q[e];
        if (x == centroid || y == centroid) continue;

        dsu.merge(x, y);
        if (dsu.size(x) < a) continue;  // x and y now share a root, so one size suffices

        // Restrict the a-part DFS to x's merged group: only entries equal to -1 are entered.
        vector<int> ans(n, 0);
        for (int v = 0; v < n; v++) {
            if (dsu.together(v, x)) ans[v] = -1;
        }

        need = a;
        dfs_fill(bigadjlist, ans, labels[0], x);

        // Everything outside the a-part becomes available for the b-part.
        for (int &lab : ans) {
            if (lab != labels[0]) lab = -1;
        }

        need = b;
        dfs_fill(adjlist, ans, labels[1], centroid);

        for (int &lab : ans) {
            if (lab == -1) lab = labels[2];
        }
        return ans;
    }

    // No valid split found.
    return vector<int>(n, 0);
}
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...