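Submission #528671 (제출 #528671)

# | Submitted at (제출 시각) | User (아이디) | Problem (문제) | Language (언어) | Result (결과) | Execution time (실행 시간) | Memory (메모리)
528671 | — | jenny00513 | Split the Attractions (IOI19_split) | C++17 | 22 / 100 | 167 ms | 32764 KiB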
#include "bits/stdc++.h"
#include "split.h"
using namespace std;

// ---- Contest-template shorthands (several are not used in this file) ----

// F / S: shorthand for std::pair members .first / .second.
#define F first
#define S second
// rep(i, n): loop i over 1..n inclusive. Note: `n` is not parenthesized.
#define rep(i, n) for (int i = 1; i <= n; i++)
// all(x): begin/end iterator pair of container x.
#define all(x) (x).begin(), (x).end()
// comp(x): sort x and drop duplicate values. Expands to TWO statements, so it is
// not safe as the body of an unbraced if/for/while.
#define comp(x) sort(all((x))); (x).erase(unique(all((x))), (x).end())
// sz(x): container size as int.
#define sz(x) (int)((x).size())
// sq(x): square. No outer parentheses, so e.g. `1 / sq(x)` parses as `1 / (x) * (x)`.
#define sq(x) (x) * (x)
// srt(x): sort the whole container ascending.
#define srt(x) sort(all(x))
#define pb push_back
#define eb emplace_back

typedef long long ll;
typedef vector<int> vi;
typedef vector<ll> vl;
typedef pair<int, int> pi;
typedef pair<ll, ll> pl;

// Output shortcuts for stdout: "Yes", "No", "-1", or a bare newline.
#define yes cout << "Yes\n"
#define no cout << "No\n"
#define imp cout << "-1\n"
#define el cout << "\n"

// Generic template constants; none of them is referenced in the visible solution.
const int MAX = 1e5 + 5;
const int LOG = 20;
const int INF = 1e9;
const ll LINF = 1e18;
const int MOD = 1e9 + 7;
// (dy, dx) grid offsets: the first four entries are the orthogonal moves,
// the last four are the diagonal moves.
const int dy[8] = { -1, 0, 1, 0, -1, 1, 1, -1 };
const int dx[8] = { 0, 1, 0, -1, 1, 1, -1, -1 };

// Extract every argument from standard input, left to right, via operator>>.
template <typename... Ts>
void read(Ts&... targets)
{
    ((std::cin >> targets), ...);
}

// One undirected edge of the input graph.
struct Edge
{
    int u;  // first endpoint
    int v;  // second endpoint
};

// Disjoint-set union (union-find) over elements 0 .. n+4.
// Find uses iterative path compression and Union attaches the smaller set under
// the larger one, so no operation recurses and chains stay shallow even for
// adversarial union orders.
class DisjointSet {
public:
    
    DisjointSet(int n): n(n), par(n + 5), s(n + 5) {
        init();
    }
    // par[x]: parent link (par[x] == x for a root); s[r]: size of the set rooted
    // at r (0 for elements that are no longer roots).
    int n; std::vector<int> par, s;
    
    // Reset every element to its own singleton set of size 1.
    void init(void)
    {
        std::iota(par.begin(), par.end(), 0);
        std::fill(s.begin(), s.end(), 1);
    }
    
    // Representative of x's set. Two passes: locate the root, then point every
    // vertex on the walked path directly at it.
    int Find(int x)
    {
        int root = x;
        while (par[root] != root) root = par[root];
        while (par[x] != root)
        {
            const int up = par[x];
            par[x] = root;
            x = up;
        }
        return root;
    }
    
    // Merge the sets containing p and q. Returns false if they were already one set.
    bool Union(int p, int q)
    {
        p = Find(p); q = Find(q);
        if (p == q) return false;
        if (s[p] < s[q]) std::swap(p, q);  // union by size: p is the larger root
        par[q] = p; s[p] += s[q]; s[q] = 0;
        return true;
    }
    
    // True iff p and q belong to the same set.
    bool same(int p, int q)
    {
        return Find(p) == Find(q);
    }

    // Number of elements in p's set.
    int Size(int p)
    {
        return s[Find(p)];
    }
};

// Capacity of the per-vertex global arrays below.
const int NODE = 200005;
// Edge capacity; only referenced by the commented-out declaration below.
const int EDGE = 400005;

// n / m: vertex and edge counts of the current find_split call.
// abc: the three requested set sizes paired with their labels 1, 2, 3
// (find_split sorts it ascending by size).
int n, m; vector<pi> abc(3);
// vi p(EDGE), q(EDGE);
// spantree: edges accepted into the DSU spanning tree; more: all other edges.
vector<Edge> spantree, more;
// Union-find shared by spanning-tree construction and the component search.
DisjointSet ds(NODE);
// adj: spanning-tree adjacency; addadj: adjacency dfsa walks to grow the smallest
// set; graph: adjacency of the full input graph; szz: subtree sizes from getSize.
vi adj[NODE], addadj[NODE], graph[NODE], szz(NODE);
// Output labelling built and returned by find_split.
vi ans;

// Fill szz[v] with the size of v's subtree for every vertex v reachable from
// `node` in the tree `adj`, treating `par` as node's parent (pass a value that
// is not a vertex, e.g. -1, for the root). Returns szz[node].
// Iterative (BFS order, then accumulate in reverse) so that path-like trees with
// ~1e5 vertices cannot overflow the call stack. Assumes adj is a forest.
int getSize(int node, int par)
{
    // (vertex, its parent); every vertex appears after its parent.
    std::vector<std::pair<int, int>> order;
    order.emplace_back(node, par);
    for (std::size_t i = 0; i < order.size(); ++i)
    {
        const int u = order[i].first;
        const int up = order[i].second;
        for (const int each : adj[u])
        {
            if (each == up) continue;
            order.emplace_back(each, u);
        }
    }

    for (const auto &entry : order) szz[entry.first] = 1;
    // Children come after their parents, so a reverse sweep finalizes each
    // subtree before it is added to its parent. Index 0 is the root itself.
    for (std::size_t i = order.size(); i-- > 1;)
        szz[order[i].second] += szz[order[i].first];

    return szz[node];
}
 
// Starting at `node` (whose parent is `par`), repeatedly step into the first
// child whose szz exceeds `half`. Returns the vertex where no such child exists.
// Requires szz to have been filled by getSize for the same rooting.
int getCent(int node, int par, int half)
{
    int cur = node;
    int from = par;
    for (;;)
    {
        int heavy = -1;
        for (const int nb : adj[cur])
        {
            if (nb != from && szz[nb] > half)
            {
                heavy = nb;
                break;
            }
        }
        if (heavy == -1) return cur;
        from = cur;
        cur = heavy;
    }
}

// Print each value of `ans` followed by a single space, then a newline, and
// terminate the process with status 0. Not called anywhere in the visible code.
void ret(vi ans)
{
    for (std::size_t idx = 0; idx < ans.size(); ++idx)
        cout << ans[idx] << ' ';
    cout << '\n';
    exit(0);
}

// vst: visited marks used by dfsa (find_split pre-marks the centroid).
// cnt: remaining vertex quota for the dfsa/dfsb traversal in progress.
vi vst(NODE);
int cnt;

// Grow the smallest set: DFS over addadj from `node`, labelling vertices with
// abc[0].S until the quota `cnt` is exhausted. Vertices already marked in vst are
// never entered. Every entered vertex is marked in vst and consumes one unit of cnt.
// The visit order, vst marks and cnt decrements match a recursive DFS exactly,
// so the labelled vertices form a connected DFS prefix. An explicit stack is
// used so a long component cannot overflow the call stack.
void dfsa(int node)
{
    // The "entry" step of the recursive DFS; false means the quota ran out.
    auto enter = [](int u) -> bool
    {
        vst[u] = true;
        cnt--;
        if (cnt < 0) return false;
        ans[u] = abc[0].S;
        return true;
    };

    if (!enter(node)) return;

    // Each frame: (vertex, index of the next neighbour to examine).
    std::vector<std::pair<int, std::size_t>> frames;
    frames.emplace_back(node, 0);
    while (!frames.empty())
    {
        const int u = frames.back().first;
        std::size_t &pos = frames.back().second;
        if (pos == addadj[u].size())
        {
            frames.pop_back();
            continue;
        }
        const int each = addadj[u][pos++];  // advance before any push invalidates pos
        if (vst[each]) continue;
        if (enter(each)) frames.emplace_back(each, 0);
    }
}

// Grow the middle set: DFS over the full graph from `node`, labelling unlabelled
// vertices (ans == 0) with abc[1].S until the quota `cnt` is exhausted.
// The visit order and cnt decrements match a recursive DFS exactly, so the
// labelled vertices form a connected DFS prefix. An explicit stack replaces
// recursion to avoid overflowing the call stack on long paths.
void dfsb(int node)
{
    // The "entry" step of the recursive DFS; false means the quota ran out.
    auto enter = [](int u) -> bool
    {
        cnt--;
        if (cnt < 0) return false;
        ans[u] = abc[1].S;
        return true;
    };

    if (!enter(node)) return;

    // Each frame: (vertex, index of the next neighbour to examine).
    std::vector<std::pair<int, std::size_t>> frames;
    frames.emplace_back(node, 0);
    while (!frames.empty())
    {
        const int u = frames.back().first;
        std::size_t &pos = frames.back().second;
        if (pos == graph[u].size())
        {
            frames.pop_back();
            continue;
        }
        const int each = graph[u][pos++];  // advance before any push invalidates pos
        if (ans[each]) continue;
        if (enter(each)) frames.emplace_back(each, 0);
    }
}

// IOI 2019 "Split the Attractions": label the N vertices of the connected graph
// with edges p[i]-q[i] using labels 1/2/3 (label k for the k-th requested size
// among a, b, c), so that the sets for the two smallest sizes are each connected.
// Returns all zeros when no component of (graph - centroid) reaches the smallest
// size, in which case no valid split exists.
//
// Approach:
//  1. Build a spanning tree with the DSU. Non-tree edges are kept in `more`.
//  2. Take the tree's centroid. Every component of (tree - cent) has <= n/2 vertices.
//  3. Look for a component of (tree - cent) with at least the smallest size. If
//     none exists, add non-tree edges that avoid the centroid one at a time until
//     some merged component reaches it.
//  4. Grow the smallest set inside that component (dfsa). Grow the middle set from
//     the centroid over the rest of the graph (dfsb). Give every remaining vertex
//     the largest label.
vector<int> find_split(int N, int a, int b, int c, vector<int> p, vector<int> q)
{
    n = N; m = sz(p);
    ans.assign(n, 0);

    // (size, label) pairs, ascending by size: abc[0] is grown first, abc[2] fills the rest.
    abc = { { a, 1 }, { b, 2 }, { c, 3 } };
    srt(abc);

    // Full adjacency plus a DSU spanning tree; rejected edges are kept for merging.
    for (int i = 0; i < m; i++)
    {
        graph[p[i]].pb(q[i]); graph[q[i]].pb(p[i]);

        if (ds.Union(p[i], q[i]))
        {
            spantree.pb(Edge{ p[i], q[i] });
            adj[p[i]].pb(q[i]); adj[q[i]].pb(p[i]);
        }
        else more.pb(Edge{ p[i], q[i] });
    }

    // Root the tree at 0 (parent -1 = none) and find its centroid.
    int limi = getSize(0, -1) / 2;
    int cent = getCent(0, -1, limi);

    // Components of (tree - cent): each tree edge not touching the centroid is
    // recorded once in addadj and unioned into a freshly reset DSU.
    ds.init();
    for (auto &[u, v] : spantree)
    {
        if (u == cent || v == cent) continue;
        addadj[u].pb(v); addadj[v].pb(u);
        ds.Union(u, v);
    }

    const int need = abc[0].F;
    int anode = -1;

    for (int i = 0; i < n; i++)
    {
        // The centroid is a singleton here and must never host the smallest set:
        // dfsb starts from it and would overwrite its label (broke need == 1).
        if (i == cent) continue;
        if (ds.Size(i) >= need)
        {
            anode = ds.Find(i);
            break;
        }
    }

    if (anode == -1)
    {
        // No tree component is large enough: merge components through non-tree
        // edges that avoid the centroid until one reaches `need`.
        for (auto &[u, v] : more)
        {
            if (u == cent || v == cent) continue;

            if (ds.Union(u, v))
            {
                addadj[u].pb(v);
                addadj[v].pb(u);
            }

            if (ds.Size(u) >= need)
            {
                anode = ds.Find(u);
                break;
            }
        }
    }

    // Every component of (graph - cent) is smaller than the smallest size: impossible.
    if (anode == -1) return ans;

    // Smallest set: a connected piece of anode's component (centroid excluded via vst).
    cnt = need; vst[cent] = true;
    dfsa(anode);

    // Middle set: grown from the centroid through still-unlabelled vertices.
    cnt = abc[1].F;
    dfsb(cent);

    for (int i = 0; i < n; i++) if (!ans[i]) ans[i] = abc[2].S;
    return ans;
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...