Submission #1241773

Submission ID: 1241773
Username: Boas
Problem: Longest Trip (IOI23_longesttrip)
Language: C++17
Result: 100 / 100
Execution time: 92 ms
Memory: 568 KiB
#include "longesttrip.h"

#include <bits/stdc++.h>
using namespace std;

// Short aliases and helper macros used throughout this solution.
using vi = std::vector<int>;
using si = std::set<int>;
#define pb push_back
#define sz(x) (static_cast<int>((x).size()))
#define loop(n, i) for (int i = 0; i < (n); ++i)
#define rev(n, i) for (int i = (n) - 1; i >= 0; --i)
#define ALL(x) begin(x), end(x)

// Memo of are_connected() answers, keyed by the (A, B) argument pair.
std::map<std::pair<vi, vi>, bool> dp;

// Stops the program immediately when an internal invariant does not hold.
// A bare `throw;` only terminates when no exception is active; inside a
// catch handler it would instead rethrow the in-flight exception.
// std::abort() stops the program in every context.
void myAssert(bool b)
{
    if (!b)
        std::abort();
}

/*void protocol_violation(string s)
{
    cerr << s << endl;
    throw;
}

int total_call_counter = 0;
int landmark_counter = 0;
inline constexpr int maxNumberOfCalls = 32640;
inline constexpr int maxTotalNumberOfCalls = 150000;
inline constexpr int maxTotalNumberOfLandmarksInCalls = 1500000;*/

// Returns true only if the memo `dp` already holds an answer for key `b`
// and that answer is "connected". Missing keys are never inserted.
// Uses a single lookup instead of count() followed by operator[].
bool hasAndSet(const pair<vi, vi> &b)
{
    auto it = dp.find(b);
    return it != dp.end() && it->second;
}

/**
 * IOI 2023 "Longest Trip": returns a sequence of distinct landmarks among
 * 0..N-1 in which every two consecutive landmarks are connected, using the
 * grader's are_connected(A, B) query.
 *
 * The parameter D is not read. Every inferred edge below assumes that among
 * any three distinct landmarks at least one pair is connected (D >= 1); the
 * inferences are unsound without that guarantee.
 *
 * Two vertex-disjoint paths are grown greedily: `res` (a deque, so it can
 * also be extended at the front) and `res2`. Afterwards, if some edge joins
 * them, they are merged into one path through all N landmarks. Otherwise the
 * longer of the two is returned; every landmark lies on one of them and no
 * edge connects them.
 *
 * All real query answers are memoised in the global `dp`, cleared on entry.
 * Extra self-checks, which may spend additional queries, are compiled only
 * when LONGEST_TRIP_DEBUG is defined.
 */
vi longest_trip(int N, int D)
{
    dp.clear();

    // Memoised are_connected(A, B). Before issuing a real query it tries to
    // answer from earlier results:
    //  - for singletons {a}, {b}: if some landmark k is known to be connected
    //    to neither a nor b, then a-b must be an edge (D >= 1);
    //  - otherwise: any memoised edge between an element of A and one of B.
    // Real answers are stored under both argument orders.
    auto conn = [&](const vi &A, const vi &B) -> bool
    {
        if (auto it = dp.find({A, B}); it != dp.end())
            return it->second;
        if (auto it = dp.find({B, A}); it != dp.end())
            return it->second;
        if (sz(A) == 1 && sz(B) == 1)
        {
            loop(N, k)
            {
                auto toA = dp.find({{k}, A});
                auto toB = dp.find({{k}, B});
                if (toA != dp.end() && toB != dp.end() && !toA->second && !toB->second)
                    return true;
            }
        }
        else
        {
            for (int a : A)
                for (int b : B)
                    if (hasAndSet({{a}, {b}}) || hasAndSet({{b}, {a}}))
                        return true;
        }
        const bool answer = are_connected(A, B);
        dp[{A, B}] = answer;
        dp[{B, A}] = answer;
        return answer;
    };

    // Every landmark < i lies on exactly one of `res` and `res2`.
    deque<int> res = {0};
    vi res2;
    for (int i = 1; i < N; i++)
    {
        if (conn({i}, {res.back()}))
        {
            res.pb(i);
        }
        else if (res2.empty())
        {
            res2 = {i};
        }
        else if (!conn({res.back()}, {res2.back()}))
        {
            // i-res.back() and res.back()-res2.back() are both non-edges, so
            // i-res2.back() must be an edge. The res.back()-res2.back() answer
            // may already be memoised from an earlier step.
            res2.pb(i);
        }
        else if (conn({i}, {res2.back()}))
        {
            res2.pb(i);
        }
        else
        {
            // i is adjacent to neither path end, but res.back()-res2.back() is
            // an edge, so res followed by reversed res2 is a valid path.
            if (i + 2 < N)
            {
                // Both answers are always needed; i+1 is queried first.
                const bool toNext = conn({i}, {i + 1});
                const bool toSecond = conn({i}, {i + 2});
                if (toNext && toSecond)
                {
                    // i+1 - i - i+2 is a path; it becomes the new res2.
                    rev(sz(res2), j) res.pb(res2[j]);
                    res2 = {i + 1, i, i + 2};
                }
                else if (!toNext && !toSecond)
                {
                    // The non-edges at i force res.back()-(i+1), (i+1)-(i+2)
                    // and (i+2)-res2.back() to be edges.
                    res.pb(i + 1);
                    res.pb(i + 2);
                    rev(sz(res2), j) res.pb(res2[j]);
                    res2 = {i};
                }
                else if (toNext)
                {
                    // i-(i+2) is a non-edge, so i+2 is adjacent to both
                    // res.back() and res2.back().
                    res.pb(i + 2);
                    rev(sz(res2), j) res.pb(res2[j]);
                    res2 = {i, i + 1};
                }
                else
                {
                    // i-(i+1) is a non-edge, so i+1 is adjacent to both
                    // res.back() and res2.back().
                    res.pb(i + 1);
                    rev(sz(res2), j) res.pb(res2[j]);
                    res2 = {i, i + 2};
                }
                i += 2; // i+1 and i+2 have already been placed
            }
            else
            {
                rev(sz(res2), j) res.pb(res2[j]);
                res2 = {i};
            }
        }
    }

    vi ret = vi(ALL(res));
    if (sz(res2) && conn(res2, ret))
    {
        // Some edge joins the two paths. First try the four end-to-end joins
        // (at most 4 more queries).
        if (conn({res.back()}, {res2[0]}))
        {
            for (int v : res2)
                res.pb(v);
            return vi(ALL(res));
        }
        if (conn({res2.back()}, {res[0]}))
        {
            for (int v : res)
                res2.pb(v);
            return res2;
        }
        if (conn({res2.back()}, {res.back()}))
        {
            rev(sz(res2), j) res.pb(res2[j]);
            return vi(ALL(res));
        }
        if (conn({res2[0]}, {res[0]}))
        {
            // Prepending one by one reverses res2, leaving res2[0] next to res[0].
            for (int v : res2)
                res.push_front(v);
            return vi(ALL(res));
        }

        // No end of one path is adjacent to an end of the other. By density,
        // each path's own two ends are then adjacent (trivially so for paths
        // of one or two vertices). So both paths are cycles and can be rotated
        // freely. Locate one crossing edge res[ix1]-res2[ix2] by binary search
        // and join the rotated cycles through it.
        int ix1, ix2;
        {
            // ix2: first vertex of res2 adjacent to some vertex of res.
            int lo = 0, hi = sz(res2) - 1;
            while (hi > lo)
            {
                int m = (hi + lo) / 2;
                vi q;
                loop(m + 1, k) q.pb(res2[k]);
                if (conn(q, ret))
                    hi = m;
                else
                    lo = m + 1;
            }
            ix2 = lo;
        }
        {
            // ix1: first vertex of res adjacent to res2[ix2].
            int lo = 0, hi = sz(res) - 1;
            while (hi > lo)
            {
                int m = (hi + lo) / 2;
                vi q;
                loop(m + 1, k) q.pb(res[k]);
                if (conn(q, {res2[ix2]}))
                    hi = m;
                else
                    lo = m + 1;
            }
            ix1 = lo;
        }
        // Rotate res so it ends at res[ix1], then rotate res2 so it starts at res2[ix2].
        vi path;
        for (int i = ix1 + 1; i < sz(res); i++)
            path.pb(res[i]);
        for (int i = 0; i <= ix1; i++)
            path.pb(res[i]);
        for (int i = ix2; i < sz(res2); i++)
            path.pb(res2[i]);
        for (int i = 0; i < ix2; i++)
            path.pb(res2[i]);
        myAssert(sz(path) == N);
#ifdef LONGEST_TRIP_DEBUG
        // May issue real queries (e.g. for the res[ix1]-res2[ix2] junction).
        for (int i = 0; i + 1 < sz(path); i++)
            myAssert(conn({path[i]}, {path[i + 1]}));
#endif
        return path;
    }

    // res2 is empty, or no edge joins the two paths: return the longer one.
#ifdef LONGEST_TRIP_DEBUG
    for (int i = 0; i + 1 < sz(res2); i++)
        myAssert(conn({res2[i]}, {res2[i + 1]}));
#endif
    if (sz(res2) > sz(res))
        return res2;
#ifdef LONGEST_TRIP_DEBUG
    for (int i = 0; i + 1 < sz(ret); i++)
        myAssert(conn({ret[i]}, {ret[i + 1]}));
#endif
    return ret;
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...