답안 #862694

# 제출 시각 아이디 문제 언어 결과 실행 시간 메모리
862694 2023-10-18T21:14:34 Z Youssif_Elkadi Logičari (COCI21_logicari) C++17
10 / 110
502 ms 524288 KB
#include <bits/stdc++.h>
using namespace std;
const long long N = 1e5 + 2, mod = 1e9 + 7;
vector<int> adj[N];
vector<int> cy;
int dp[N][4], dp2[N][5][5];
bool vis[N];
int par[N];
int n, hi;
void cyc(int u)
{
    vis[u] = 1;
    for (auto v : adj[u])
    {
        if (v == par[u])
            continue;
        par[v] = u;
        if (vis[v])
            hi = v;
        if (hi)
            return;
        cyc(v);
    }
    vis[u] = 0;
}
void calc(int u, int p)
{
    bool flag = 0;
    int cnt[4] = {}, sum[4] = {};
    for (auto v : adj[u])
    {
        if (vis[v] || v == p)
            continue;
        calc(v, u);
        for (int i = 0; i < 4; i++)
            cnt[i] += (dp[v][i] == -1), sum[i] += dp[v][i];
        flag = 1;
    }
    if (!flag)
    {
        dp[u][3] = dp[u][1] = 0;
        dp[u][0] = dp[u][2] = -1;
        return;
    }
    for (int i = 0; i < 4; i++)
        dp[u][i] = mod;
    for (auto v : adj[u])
    {
        if (vis[v] || v == p)
            continue;
        // calc dp[u][0]
        if (dp[v][1] != -1 && cnt[3] - (dp[v][3] == -1) == 0)
            dp[u][0] = min(dp[u][0], dp[v][1] + (sum[3] - dp[v][3]) + 2);
        // calc dp[u][1]
        if (cnt[3] == 0)
            dp[u][1] = sum[3];
        // calc dp[u][2]
        if (dp[v][0] != -1 && cnt[2] - (dp[v][2] == -1) == 0)
            dp[u][2] = min(dp[u][2], dp[v][0] + (sum[2] - dp[v][2]));
        // calc dp[u][3]
        if (cnt[2] == 0)
            dp[u][3] = sum[2];
    }
    for (int i = 0; i < 4; i++)
        dp[u][i] = (dp[u][i] == mod ? -1 : dp[u][i]);
}
/*
dp[u][0]= one dp[v][1] AND all dp[v][3] (source)
dp[u][1]= all dp[v][3] (partner)
dp[u][2]= one dp[v][0] AND all dp[v][2] (parent)
dp[u][3]= all dp[v][2] (nothing)
*/
/*
states:
0: means i am blue
1: means i am sec blue
2: i am after 1 OR 4
3: i am behind 0 OR 4
4: means i am blue and i dont want partner
*/
int solve(int ind, int lst, int st)
{
    if (~dp2[ind][lst][st])
        return dp2[ind][lst][st];
    if (ind == cy.size())
    {
        if (lst == 0 && st == 1)
            return 0;
        if (lst == 1 && st == 2)
            return 0;
        if (lst == 2 && st == 3)
            return 0;
        if (lst == 3 && (st == 0 || st == 4))
            return 0;
        if (lst == 4 && st == 2)
            return 0;
        return mod;
    }
    int ret = mod;
    if (ind == 0)
    {
        // ret = solve(ind + 1, 4, 4) + dp[cy[ind]][0];
        for (int i = 0; i < 5; i++)
        {
            if ((i == 2 || i == 3) && dp[cy[ind]][3] != -1)
                ret = min(ret, solve(ind + 1, i, i) + dp[cy[ind]][3]);
            else if (i == 4 && dp[cy[ind]][0] != -1)
                ret = min(ret, solve(ind + 1, i, i) + dp[cy[ind]][0]);
            else if ((i == 0 || i == 1) && dp[cy[ind]][1] != -1)
                ret = min(ret, solve(ind + 1, i, i) + dp[cy[ind]][1] + (i == 0) * 2);
        }
    }
    else
    {
        if (lst == 0 && dp[cy[ind]][1] != -1)
            ret = solve(ind + 1, 1, st) + dp[cy[ind]][1];
        if (lst == 1 && dp[cy[ind]][3] != -1)
            ret = solve(ind + 1, 2, st) + dp[cy[ind]][3];
        if (lst == 2 && dp[cy[ind]][3] != -1)
            ret = solve(ind + 1, 3, st) + dp[cy[ind]][3];
        if (lst == 3 && dp[cy[ind]][1] != -1)
            ret = min(ret, solve(ind + 1, 0, st) + dp[cy[ind]][1] + 2);
        if (lst == 3 && dp[cy[ind]][0] != -1)
            ret = min(ret, solve(ind + 1, 4, st) + dp[cy[ind]][0]);
        if (lst == 4 && dp[cy[ind]][3] != -1)
            ret = min(ret, solve(ind + 1, 2, st) + dp[cy[ind]][3]);
    }
    return dp2[ind][lst][st] = ret;
}
int main()
{
    memset(dp2, -1, sizeof dp2);
    cin >> n;
    for (int i = 0; i < n; i++)
    {
        int x, y;
        cin >> x >> y;
        adj[x].push_back(y);
        adj[y].push_back(x);
    }
    cyc(1);
    int tmp = hi;
    while (true)
    {
        cy.push_back(tmp);
        vis[tmp] = 1;
        tmp = par[tmp];
        if (tmp == hi)
            break;
    }
    for (int i = 1; i <= n; i++)
        vis[i] = 0;
    for (auto v : cy)
        vis[v] = 1;
    for (auto v : cy)
        calc(v, 0);
    int tmpp = solve(0, 0, 0); // (2 4 3)
    // cout << dp[2][0] << " " << dp[4][3] << " " << dp[3][3] << " ";
    if (tmpp == mod)
        cout << -1;
    else
        cout << tmpp;
}

Compilation message

Main.cpp: In function 'int solve(int, int, int)':
Main.cpp:85:13: warning: comparison of integer expressions of different signedness: 'int' and 'std::vector<int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
   85 |     if (ind == cy.size())
      |         ~~~~^~~~~~~~~~~~
# 결과 실행 시간 메모리 Grader output
1 Correct 2 ms 14424 KB Output is correct
2 Correct 2 ms 14428 KB Output is correct
3 Correct 2 ms 14428 KB Output is correct
4 Correct 2 ms 14440 KB Output is correct
5 Correct 78 ms 25852 KB Output is correct
6 Correct 82 ms 26068 KB Output is correct
7 Correct 81 ms 26080 KB Output is correct
8 Correct 91 ms 26068 KB Output is correct
# 결과 실행 시간 메모리 Grader output
1 Runtime error 502 ms 524288 KB Execution killed with signal 9
2 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Runtime error 502 ms 524288 KB Execution killed with signal 9
2 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Correct 2 ms 14424 KB Output is correct
2 Correct 2 ms 14428 KB Output is correct
3 Correct 2 ms 14428 KB Output is correct
4 Correct 2 ms 14440 KB Output is correct
5 Correct 78 ms 25852 KB Output is correct
6 Correct 82 ms 26068 KB Output is correct
7 Correct 81 ms 26080 KB Output is correct
8 Correct 91 ms 26068 KB Output is correct
9 Runtime error 502 ms 524288 KB Execution killed with signal 9
10 Halted 0 ms 0 KB -