// Submission #832099
//
// Username:       benjaminkleyn
// Problem:        버섯 세기 (Counting Mushrooms), IOI20_mushrooms
// Language:       C++17
// Result:         75.33 / 100
// Execution time: 1475 ms
// Memory:         26060 KiB
#include "mushrooms.h"
#include <bits/stdc++.h>
using namespace std;
// Extended-precision floating type used for the information scores computed
// while building the decision trees.
using ld = long double;

// Counts adjacent positions in `x` holding different characters, i.e. how
// many times the sequence switches species. The tree builders use it as the
// model of what use_machine returns for the same arrangement.
// Strings of length 0 or 1 have no adjacent pairs and yield 0. The old bound
// `x.size() - 1` wrapped around for an empty string and read out of bounds.
int f(const string &x)
{
    int cnt = 0;
    for (size_t i = 1; i < x.size(); i++)
        cnt += (x[i - 1] != x[i]);
    return cnt;
}

// A fixed reordering of positions. Applying it to a sequence `s` yields
// s[p[0]], s[p[1]], ..., s[p[p.size()-1]].
// Every entry of `p` must be a valid index into the sequence it is applied to.
struct Permutation
{
    vector<int> p;
    Permutation() = default;

    // Returns the characters of `s` rearranged in the order given by `p`.
    string operator()(const string &s) const
    {
        string res;
        res.reserve(p.size());
        for (int i : p)
            res.push_back(s[i]);
        return res;
    }

    // Returns the elements of `v` rearranged in the order given by `p`.
    vector<int> operator()(const vector<int> &v) const
    {
        vector<int> res;
        res.reserve(p.size());
        for (int i : p)
            res.push_back(v[i]);
        return res;
    }
};

vector<Permutation> p1, p2;

void compute_permutations()
{
    p1.push_back(Permutation());
    p1[0].p = {0, 1, 2, 3, 4, 5};
    for (int i = 1; i < 720; i++)
    {
        p1.push_back(p1.back());
        next_permutation(p1.back().p.begin(), p1.back().p.end());
    }
    p2.push_back(Permutation());
    p2[0].p = {0, 1, 2, 3, 4, 5, 6, 7, 8};
    for (int i = 1; i < 362880; i++)
    {
        p2.push_back(p2.back());
        next_permutation(p2.back().p.begin(), p2.back().p.end());
    }
}

struct TreeNode
{
    vector<TreeNode*> children;
    bool is_end = false;
    string val;
    Permutation p;
    string find(vector<int> indices)
    {
        return is_end ? val : children[use_machine(p(indices))]->find(indices);
    }
};

void build1_recursive(TreeNode *cur, vector<string> &possible)
{
    if (possible.size() == 0)
        return;
    if (possible.size() == 1)
    {
        cur->is_end = true;
        cur->val = *possible.begin();
        return;
    }
    ld bestI = 0.0;
    for (Permutation &p : p1)
    {
        vector<int> cnt(10);
        for (string &s : possible)
            cnt[f(p(s))]++;
        ld I = 0.0;
        for (int i = 0; i < 10; i++)
            if (cnt[i] > 0)
                I += cnt[i] * log2(((ld)possible.size()) / ((ld)cnt[i]));
        if (I > bestI)
            bestI = I, cur->p = p;
    }
    cur->children = vector<TreeNode*>();
    for (int i = 0; i < 10; i++)
    {
        cur->children.push_back(new TreeNode());
        vector<string> newpossible;
        for (string &s : possible)
            if (f(cur->p(s)) == i)
                newpossible.push_back(s);
        build1_recursive(cur->children[i], newpossible);
    }
}

void build2_recursive(TreeNode *cur, vector<string> &possible)
{
    if (possible.size() == 0)
        return;
    if (possible.size() == 1)
    {
        cur->is_end = true;
        cur->val = *possible.begin();
        return;
    }
    ld bestI = 0.0;
    for (Permutation &p : p2)
    {
        vector<int> cnt(10);
        for (string &s : possible)
            cnt[f(p(s))]++;
        ld I = 0.0;
        for (int i = 0; i < 10; i++)
            if (cnt[i] > 0)
                I += cnt[i] * log2(((ld)possible.size()) / ((ld)cnt[i])) / ((ld)possible.size());
        if (I > bestI)
            bestI = I, cur->p = p;
    }
    cur->children = vector<TreeNode*>();
    for (int i = 0; i < 10; i++)
    {
        cur->children.push_back(new TreeNode());
        vector<string> newpossible;
        for (string &s : possible)
            if (f(cur->p(s)) == i)
                newpossible.push_back(s);
        build2_recursive(cur->children[i], newpossible);
    }
}

// Decision-tree roots.
// root1 is built over arrangements "A" + five unknowns (6 positions).
// root2 is built over arrangements "AAAB" + five unknowns (9 positions).
TreeNode *root1, *root2;

// Prepares the permutation tables and builds both decision trees.
// The five unknown positions are enumerated from a 5-bit mask: bit j set
// means position j is 'A', clear means 'B'.
void build_trees()
{
    compute_permutations();
    root1 = new TreeNode();
    root2 = new TreeNode();

    vector<string> possible1, possible2;
    for (int mask = 0; mask < 32; ++mask)
    {
        string tail(5, 'B');
        for (int bit = 0; bit < 5; ++bit)
            if ((mask >> bit) & 1)
                tail[bit] = 'A';
        possible1.push_back("A" + tail);
        possible2.push_back("AAAB" + tail);
    }

    build1_recursive(root1, possible1);
    build2_recursive(root2, possible2);
}

// Returns how many of the mushrooms 0..n-1 are of the same species as
// mushroom 0 (called species A here), using the use_machine oracle.
//
// Strategy:
//   - n < 6: compare every mushroom directly with mushroom 0.
//   - Phase 1: classify five unknowns per query with the 6-position tree
//     (root1), anchored on mushroom 0, until some B is known.
//   - Phase 2: until index kTreePhaseLimit, classify five unknowns per query
//     with the 9-position tree (root2). Each query is anchored on three
//     known mushrooms of the larger class and one of the other class.
//   - Phase 3: interleave all known mushrooms of the larger class with
//     unknowns. The transition count gives how many unknowns differ from
//     that class, and its parity classifies the last unknown.
int count_mushrooms(int n)
{
    if (n < 6)
    {
        int cnt = 1;
        for (int i = 1; i < n; i++)
            cnt += 1 - use_machine({0, i});
        return cnt;
    }

    // The trees do not depend on n, and building them scans every
    // permutation in p1/p2. Build them only on the first call instead of
    // rebuilding (and leaking) them on each call.
    static bool trees_built = false;
    if (!trees_built)
    {
        build_trees();
        trees_built = true;
    }

    vector<int> A, B;   // indices known to be species A / species B
    A.push_back(0);
    int i = 1;          // next unclassified index

    // Phase 1: query {0, i, ..., i+4} until a B has been found.
    for (; i + 4 < n && B.empty(); i += 5)
    {
        string code = root1->find({0, i, i + 1, i + 2, i + 3, i + 4});
        for (int j = 0; j < 5; j++)
            (code[j + 1] == 'A' ? A : B).push_back(i + j);
    }

    // Phase 2: root2 was built for "AAAB" + five unknowns, so an 'A' in its
    // answer means "same species as the three anchors". This phase runs only
    // when phase 1 found a B. At least six mushrooms are classified by then,
    // so the larger class has >= 3 members and the smaller one >= 1.
    const int kTreePhaseLimit = 100;
    for (; i + 4 < n && i < kTreePhaseLimit; i += 5)
    {
        const bool a_larger = A.size() > B.size();
        vector<int> &major = a_larger ? A : B;
        vector<int> &minor = a_larger ? B : A;
        string code = root2->find({major[0], major[1], major[2], minor[0],
                                   i, i + 1, i + 2, i + 3, i + 4});
        for (int j = 0; j < 5; j++)
            (code[j + 4] == 'A' ? major : minor).push_back(i + j);
    }

    // Phases 1-2 classified every examined mushroom, so the A count so far is
    // exactly |A|. Phase 3 counts more A's than it records in A.
    int cnt = (int)A.size();

    // Phase 3: query [k0, u0, k1, u1, ...] with known mushrooms k of one
    // class. Every unknown except the last sits between two known mushrooms
    // of the same class and adds 0 or 2 transitions; the last adds 0 or 1.
    while (i < n)
    {
        const bool use_a = A.size() >= B.size();
        vector<int> &known = use_a ? A : B;
        vector<int> &other = use_a ? B : A;

        vector<int> query;
        for (int k : known)
        {
            query.push_back(k);
            query.push_back(i++);
            if (i >= n)
                break;
        }
        int x = use_machine(query);

        int unknowns = (int)(query.size() / 2);
        int differ = (x + 1) / 2;   // unknowns of the other species
        cnt += use_a ? unknowns - differ : differ;
        // x is odd exactly when the last unknown differs from `known`.
        (x % 2 ? other : known).push_back(query.back());
    }
    return cnt;
}

// Compilation message (stderr):
//
// mushrooms.cpp: In function 'int f(std::string)':
// mushrooms.cpp:9:23: warning: comparison of integer expressions of different signedness: 'int' and 'std::__cxx11::basic_string<char>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
//     9 |     for (int i = 0; i < x.size() - 1; i++)
//       |                     ~~^~~~~~~~~~~~~~
//
// Per-test results (# / Verdict / Execution time / Memory / Grader output):
// not captured; the page still showed "Fetching results...".