Submission #144345

# | Time | Username | Problem | Language | Result | Execution time | Memory
144345 | — | johutha | Split the Attractions (IOI19_split) | C++14 | 40 / 100 | 168 ms | 22016 KiB
#include <iostream>
#include <vector>
#include "split.h"
#include <queue>
#include <algorithm>

using namespace std;

/// Labels the vertices of a connected graph with three labels so that the two
/// smallest requested parts are connected; the third part receives whatever is
/// left over.
///
/// The solver works on a spanning tree (`adjlist`) and, when `graph` points to
/// the full adjacency lists, also uses non-tree edges to glue several subtrees
/// of the tree centroid together. If `graph` is null, the spanning tree itself
/// is used as the graph, which is the pure-tree case.
struct treesolver
{
    std::vector<std::vector<int>> adjlist;   // spanning-tree adjacency lists
    std::vector<std::pair<int,int>> r;       // (requested size, label), sorted ascending by size
    std::vector<int> subtc;                  // subtree sizes, filled by cntdfs()
    std::vector<int> sol;                    // resulting label per vertex, 0 = unassigned
    int n;                                   // number of vertices
    const std::vector<std::vector<int>> *graph = nullptr;   // optional full graph (not owned)

    /// Creates a solver for `in` vertices with an empty spanning tree.
    explicit treesolver(int in) : adjlist(in), subtc(in), sol(in), n(in) {}

    /// Fills `subtc` for the subtree rooted at `curr` (entered from `par`) and
    /// returns its size. Iterative so that path-like trees cannot overflow the
    /// call stack.
    int cntdfs(int curr, int par)
    {
        // BFS order guarantees that every parent appears before its children.
        std::vector<std::pair<int,int>> order;   // (vertex, parent)
        order.push_back({curr, par});
        for (int i = 0; i < static_cast<int>(order.size()); i++)
        {
            const int v = order[i].first;
            const int from = order[i].second;
            subtc[v] = 1;
            for (int next : adjlist[v])
            {
                if (next != from) order.push_back({next, v});
            }
        }

        // Walk the order backwards so each child is final before it is added
        // to its parent. Index 0 is the root, whose parent may be -1.
        for (int i = static_cast<int>(order.size()) - 1; i >= 1; i--)
        {
            subtc[order[i].second] += subtc[order[i].first];
        }
        return subtc[curr];
    }

    /// Descends from `curr` into the child whose subtree holds more than n / 2
    /// vertices until no such child exists. Returns (centroid, its parent).
    /// Requires `subtc` to be filled for the same rooting.
    std::pair<int,int> finddfs(int curr, int par)
    {
        while (true)
        {
            int heavy = -1;
            for (int next : adjlist[curr])
            {
                if (next != par && subtc[next] > n / 2)
                {
                    heavy = next;
                    break;
                }
            }
            if (heavy == -1) return {curr, par};
            par = curr;
            curr = heavy;
        }
    }

    /// Roots the spanning tree at vertex 0 and returns (centroid, parent of centroid).
    std::pair<int,int> findcentroid()
    {
        cntdfs(0, -1);
        return finddfs(0, -1);
    }

    /// Breadth-first colouring over `g` starting at `st`: assigns `color` to
    /// unassigned vertices accepted by `allowed` until `num` vertices have been
    /// coloured or nothing more is reachable.
    template <class Allowed>
    void flood(const std::vector<std::vector<int>> &g, int st, int color, int num, Allowed allowed)
    {
        std::queue<int> q;
        q.push(st);

        while (!q.empty())
        {
            const int curr = q.front();
            q.pop();
            // A vertex can be queued several times; only the first visit colours it.
            if (sol[curr] != 0 || !allowed(curr)) continue;
            sol[curr] = color;
            num--;
            if (num < 1) break;
            for (int next : g[curr])
            {
                if (sol[next] == 0 && allowed(next)) q.push(next);
            }
        }
    }

    /// Colours up to `num` unassigned spanning-tree vertices reachable from
    /// `st` without ever entering `par` (pass -1 for "no forbidden vertex").
    void paint(int st, int par, int color, int num)
    {
        flood(adjlist, st, color, num, [par](int v) { return v != par; });
    }

    /// Computes the labelling into `sol`. `sol` stays all zero when no group of
    /// centroid subtrees that is connected without the centroid reaches the
    /// smallest requested size.
    void solve()
    {
        if (n <= 0 || r.size() < 3) return;

        // Fall back to the spanning tree when no usable full graph was supplied.
        const std::vector<std::vector<int>> &g =
            (graph != nullptr && static_cast<int>(graph->size()) == n) ? *graph : adjlist;

        const int c = findcentroid().first;
        const int k = static_cast<int>(adjlist[c].size());

        // comp[v] = index of the centroid-neighbour subtree containing v (-1 for the centroid).
        std::vector<int> comp(n, -1);
        std::vector<std::vector<int>> members(k);
        for (int i = 0; i < k; i++)
        {
            const int root = adjlist[c][i];
            comp[root] = i;
            members[i].push_back(root);
            for (int h = 0; h < static_cast<int>(members[i].size()); h++)
            {
                const int v = members[i][h];
                for (int next : adjlist[v])
                {
                    if (next == c || comp[next] != -1) continue;
                    comp[next] = i;
                    members[i].push_back(next);
                }
            }
        }

        const int need = r[0].first;
        std::vector<char> chosen(k, 0);   // subtrees that will host the smallest part
        int start = -1;                   // a vertex inside the chosen group

        // A single subtree that is large enough is preferred: it holds at most
        // n / 2 vertices, so enough vertices remain around the centroid for
        // the second part.
        for (int i = 0; i < k && start == -1; i++)
        {
            if (static_cast<int>(members[i].size()) >= need)
            {
                chosen[i] = 1;
                start = members[i][0];
            }
        }

        // Otherwise every subtree is smaller than `need`. Merge subtrees that
        // are linked by graph edges avoiding the centroid, one at a time, and
        // stop as soon as the group is large enough (it then has fewer than
        // 2 * need vertices).
        std::vector<char> seen(k, 0);
        for (int i = 0; i < k && start == -1; i++)
        {
            if (seen[i]) continue;
            std::vector<int> order(1, i);
            seen[i] = 1;
            int total = 0;
            for (int h = 0; h < static_cast<int>(order.size()); h++)
            {
                const int ci = order[h];
                total += static_cast<int>(members[ci].size());
                if (total >= need)
                {
                    for (int t = 0; t <= h; t++) chosen[order[t]] = 1;
                    start = members[i][0];
                    break;
                }
                for (int v : members[ci])
                {
                    for (int u : g[v])
                    {
                        const int cu = comp[u];   // -1 for the centroid
                        if (cu >= 0 && !seen[cu])
                        {
                            seen[cu] = 1;
                            order.push_back(cu);
                        }
                    }
                }
            }
        }

        if (start == -1) return;

        // Smallest part: stays inside the chosen group of subtrees.
        flood(g, start, r[0].second, r[0].first,
              [&](int v) { return comp[v] >= 0 && chosen[comp[v]] != 0; });
        // Middle part: grows from the centroid over everything still unassigned.
        flood(g, c, r[1].second, r[1].first, [](int) { return true; });

        for (int i = 0; i < n; i++)
        {
            if (sol[i] == 0) sol[i] = r[2].second;
        }
    }
};

// Scratch state shared by dfs() and find_split(); reset on every find_split() call.
std::vector<bool> visited;
std::vector<std::vector<int>> adjlist;

/// Depth-first traversal of the global `adjlist` starting at `curr`; every edge
/// that discovers a new vertex is added (in both directions) to `ts.adjlist`.
/// Uses an explicit stack instead of recursion so deep graphs are safe.
void dfs(int curr, treesolver &ts)
{
    std::vector<std::pair<int,int>> stack;   // (vertex, index of the next neighbour to try)
    visited[curr] = true;
    stack.push_back({curr, 0});

    while (!stack.empty())
    {
        const int v = stack.back().first;
        if (stack.back().second == static_cast<int>(adjlist[v].size()))
        {
            stack.pop_back();
            continue;
        }
        const int next = adjlist[v][stack.back().second++];
        if (visited[next]) continue;
        visited[next] = true;
        ts.adjlist[v].push_back(next);
        ts.adjlist[next].push_back(v);
        stack.push_back({next, 0});
    }
}

/// Entry point: given a graph with `n` vertices and edges (p[i], q[i]), returns
/// a label (1, 2 or 3) per vertex with exactly `a`, `b` and `c` vertices per
/// label such that the two smallest parts are connected, or all zeros when the
/// solver finds no such labelling.
std::vector<int> find_split(int n, int a, int b, int c, std::vector<int> p, std::vector<int> q)
{
    if (n <= 0) return {};

    std::vector<std::pair<int,int>> r = {{a, 1}, {b, 2}, {c, 3}};
    std::sort(r.begin(), r.end());
    treesolver ts(n);
    ts.r = r;

    // assign() (not resize()) so nothing survives from an earlier call.
    adjlist.assign(n, std::vector<int>());
    visited.assign(n, false);

    const int m = static_cast<int>(std::min(p.size(), q.size()));
    for (int i = 0; i < m; i++)
    {
        adjlist[p[i]].push_back(q[i]);
        adjlist[q[i]].push_back(p[i]);
    }

    dfs(0, ts);

    ts.graph = &adjlist;
    ts.solve();
    return ts.sol;
}
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...