제출 #1066234

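# | 제출 시각 (submitted at) | 아이디 (user) | 문제 (problem) | 언어 (language) | 결과 (result) | 실행 시간 (execution time) | 메모리 (memory)
1066234 | arbuzick | Split the Attractions (IOI19_split) | C++17
100 / 100
111 ms | 25020 KiB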
#include "split.h"

#include <bits/stdc++.h>

using namespace std;

// Depth-first search from `v` over the graph `g`.
// Every vertex that is not yet marked in `used` is visited and becomes a tree child
// of the vertex it was reached from. The tree edge is recorded in both directions in
// `g_tree`. On return, sz[x] holds the number of vertices in x's DFS subtree for
// every vertex x visited by this call.
void build_tree(int v, vector<vector<int>> &g, vector<int> &used, vector<int> &sz, vector<vector<int>> &g_tree) {
    used[v] = 1;
    int subtree_size = 1;
    const int degree = static_cast<int>(g[v].size());
    for (int idx = 0; idx < degree; ++idx) {
        const int next = g[v][idx];
        if (used[next]) {
            continue;  // already part of the tree
        }
        g_tree[v].push_back(next);
        g_tree[next].push_back(v);
        build_tree(next, g, used, sz, g_tree);
        subtree_size += sz[next];
    }
    sz[v] = subtree_size;
}

// Assigns preorder entry times over the tree `g_tree`, starting at `v`, whose parent
// is `pr` (-1 for the root).
// The first vertex gets the current value of `t`. The counter is advanced once per
// vertex, so on return `t` is one past the last time handed out.
void calc_tin(int v, int pr, vector<vector<int>> &g_tree, int &t, vector<int> &tin) {
    tin[v] = t;
    ++t;
    for (const int child : g_tree[v]) {
        if (child == pr) {
            continue;  // do not walk back up the tree
        }
        calc_tin(child, v, g_tree, t, tin);
    }
}

// For every vertex in the subtree of `v` (tree parent `pr`, -1 for the root), computes
// tup[x] as the minimum of:
//   * tin[x];
//   * tup of each tree child of x;
//   * tin of every graph neighbour of x other than its tree parent.
// This is the classic DFS low-link value based on entry times `tin`.
void calc_tup(int v, int pr, vector<vector<int>> &g_tree, vector<vector<int>> &g, vector<int> &tin, vector<int> &tup) {
    tup[v] = tin[v];
    for (const int child : g_tree[v]) {
        if (child == pr) {
            continue;
        }
        calc_tup(child, v, g_tree, g, tin, tup);
        if (tup[child] < tup[v]) {
            tup[v] = tup[child];
        }
    }
    for (const int nb : g[v]) {
        // Edges back to the tree parent are ignored.
        if (nb != pr && tin[nb] < tup[v]) {
            tup[v] = tin[nb];
        }
    }
}

// Paints exactly `count` vertices with `color` using a BFS over tree edges, starting at
// `start`.
// Only neighbours that still hold `free_color` are enqueued, and vertex `banned` is
// never enqueued. So the painted set is a connected piece of the tree component that
// contains `start` once the edge (start, banned) is cut.
// Precondition (ensured by the caller): that component has at least `count` free vertices.
static void paint_tree_region(int start, int banned, int count, int color, int free_color,
                              const vector<vector<int>> &g_tree, vector<int> &res) {
    queue<int> frontier;
    frontier.push(start);
    for (int painted = 0; painted < count; ++painted) {
        const int v = frontier.front();
        frontier.pop();
        res[v] = color;
        for (int u : g_tree[v]) {
            if (res[u] == free_color && u != banned) {
                frontier.push(u);
            }
        }
    }
}

// Paints exactly `count` vertices with `color` using a BFS over the input graph `g`,
// starting at `start`.
// A neighbour is enqueued only if it still holds `free_color` and, when `allowed` is
// non-null, (*allowed)[u] is non-zero. Vertices are painted when popped, so one vertex
// may be queued more than once; repeats are skipped.
// Precondition: the reachable free (and allowed) region holds at least `count` vertices.
static void paint_graph_region(int start, int count, int color, int free_color,
                               const vector<vector<int>> &g, const vector<int> *allowed,
                               vector<int> &res) {
    queue<int> frontier;
    frontier.push(start);
    int painted = 0;
    while (painted < count) {
        const int v = frontier.front();
        frontier.pop();
        if (res[v] == color) {
            continue;  // already painted via an earlier queue entry
        }
        res[v] = color;
        ++painted;
        for (int u : g[v]) {
            if (res[u] == free_color && (allowed == nullptr || (*allowed)[u])) {
                frontier.push(u);
            }
        }
    }
}

// IOI 2019 "Split the Attractions".
// Labels the n vertices of the graph (edges p[i]-q[i]) with 1, 2 and 3 so that exactly
// a, b and c vertices receive each label, and the two smallest label classes induce
// connected subgraphs. The largest class takes every remaining vertex.
// Returns the labelling, or a vector of n zeros when no split is found.
//
// Sizes are sorted internally to a <= b <= c; col[] remembers which original label
// each sorted size corresponds to.
vector<int> find_split(int n, int a, int b, int c, vector<int> p, vector<int> q) {
    // col[k] = label (1, 2 or 3) of the k-th smallest requested size.
    vector<int> col = {1, 2, 3};
    if (a > b) {
        swap(a, b);
        swap(col[0], col[1]);
    }
    if (b > c) {
        swap(b, c);
        swap(col[1], col[2]);
    }
    if (a > b) {
        swap(a, b);
        swap(col[0], col[1]);
    }

    // Adjacency lists of the input graph.
    const int m = static_cast<int>(p.size());
    vector<vector<int>> g(n);
    for (int i = 0; i < m; ++i) {
        g[p[i]].push_back(q[i]);
        g[q[i]].push_back(p[i]);
    }

    // DFS spanning tree rooted at vertex 0, with subtree sizes.
    vector<int> used(n);
    vector<int> sz(n);
    vector<vector<int>> g_tree(n);
    build_tree(0, g, used, sz, g_tree);

    // Phase 1: look for a tree edge (i, parent) whose removal leaves one side with
    // >= a vertices and the other with >= b. Grow each class from its own endpoint.
    vector<int> prv(n, -1);
    for (int i = 0; i < n; ++i) {
        // The parent is the only tree neighbour with a larger subtree; -1 for the root.
        int pr = -1;
        for (auto j : g_tree[i]) {
            if (sz[j] > sz[i]) {
                pr = j;
            }
        }
        prv[i] = pr;
        // Assuming a connected graph and positive sizes, n - sz[root] == 0, so neither
        // branch fires for the root and pr == -1 is never used as a BFS start.
        if (sz[i] >= a && n - sz[i] >= b) {
            vector<int> res(n, col[2]);
            paint_tree_region(i, pr, a, col[0], col[2], g_tree, res);
            paint_tree_region(pr, i, b, col[1], col[2], g_tree, res);
            return res;
        } else if (sz[i] >= b && n - sz[i] >= a) {
            vector<int> res(n, col[2]);
            paint_tree_region(pr, i, a, col[0], col[2], g_tree, res);
            paint_tree_region(i, pr, b, col[1], col[2], g_tree, res);
            return res;
        }
    }

    // Phase 2: no single tree edge works. tin = DFS entry times.
    // tup[v] = smallest tin reachable from v's subtree through one edge that is not
    // the edge to a tree parent.
    vector<int> tin(n), tup(n);
    int t = 1;
    calc_tin(0, -1, g_tree, t, tin);
    calc_tup(0, -1, g_tree, g, tin, tup);

    // Centroid: first vertex whose child subtrees and "upper" part all have size <= n/2.
    int centr = -1;
    for (int i = 0; i < n; ++i) {
        bool check = true;
        for (auto j : g_tree[i]) {
            if (sz[j] < sz[i] && sz[j] * 2 > n) {
                check = false;
            }
        }
        if ((n - sz[i]) * 2 > n) {
            check = false;
        }
        if (check) {
            centr = i;
            break;
        }
    }

    // If the centroid is the DFS root there is no upper part to grow from; fall through
    // to the zero vector. The centr != -1 check only guards against invalid input.
    if (centr != -1 && prv[centr] != -1) {
        // Seed the a-side with everything above the centroid. Then absorb child subtrees
        // that contain an edge to a vertex with tin < tin[centr], until the side holds
        // at least a vertices.
        vector<int> part_a = {prv[centr]};
        int sz_a = n - sz[centr];
        for (auto v : g_tree[centr]) {
            if (v != prv[centr] && tup[v] < tin[centr]) {
                sz_a += sz[v];
                part_a.push_back(v);
                if (sz_a >= a) {
                    break;
                }
            }
        }
        if (sz_a >= a) {
            // Mark every vertex of the chosen pieces (tree BFS that never enters centr).
            vector<int> part_a_used(n);
            for (auto root : part_a) {
                queue<int> frontier;
                frontier.push(root);
                while (!frontier.empty()) {
                    int v = frontier.front();
                    frontier.pop();
                    part_a_used[v] = true;
                    for (auto u : g_tree[v]) {
                        if (!part_a_used[u] && u != centr) {
                            frontier.push(u);
                        }
                    }
                }
            }
            // The a-class grows over graph edges restricted to the marked vertices.
            // The b-class then grows from the centroid over any still-unpainted vertex.
            vector<int> res(n, col[2]);
            paint_graph_region(prv[centr], a, col[0], col[2], g, &part_a_used, res);
            paint_graph_region(centr, b, col[1], col[2], g, nullptr, res);
            return res;
        }
    }
    // No valid split was found.
    return vector<int>(n, 0);
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...