답안 #854981

| # | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리 |
|---|---|---|---|---|---|---|---|
| 854981 | 2023-09-29T14:39:22Z | Tenis0206 | W (RMI18_w) | C++11 | 100 / 100 | 284 ms | 5204 KB |
#include <bits/stdc++.h>

using namespace std;

// Upper bound on the number of input values; the per-element arrays below are sized from it.
constexpr int nmax = 300000;
// Modulus applied to every count computed in this file.
constexpr int Mod = 1000000007;

// Input: n values stored 1-indexed in v[1..n] (v[0] is never written and stays 0).
int n;
int v[nmax + 5];

// Groups of equal values, 1-indexed: first = value, second = multiplicity (filled in main).
pair<int,int> c[nmax + 5];

// inv[i] = modular inverse of i modulo Mod, filled in main for i = 1..n.
int inv[nmax + 5];

// DP tables indexed by a 5-bit mask: dp holds the current layer, aux the layer being built.
int dp[(1<<5) + 5];
int aux[(1<<5) + 5];

// Masks the DP iterates over, and for each mask the predecessor masks it may be reached
// from (both filled in main).
vector<int> valid_mask;
vector<int> valid_last_mask[(1<<5) + 5];

/// Computes a^b modulo Mod by binary exponentiation (square-and-multiply).
/// @param a base (in this file always a value in [1, n])
/// @param b exponent; the loop only terminates for b >= 0
/// @return a^b mod Mod; 1 when b == 0
int lgput(int a, int b)
{
    long long base = a;
    long long result = 1;
    // Walk the exponent's bits from least significant upwards; base holds a^(2^k).
    for (int e = b; e != 0; e >>= 1)
    {
        if ((e & 1) != 0)
        {
            result = result * base % Mod;
        }
        base = base * base % Mod;
    }
    return static_cast<int>(result);
}

/// Modular inverse of x via Fermat's little theorem: x^(Mod-2) mod Mod.
/// Valid because Mod (1e9+7) is prime; x must not be a multiple of Mod.
int invmod(int x)
{
    const int fermat_exponent = Mod - 2;
    return lgput(x, fermat_exponent);
}

/// Binomial coefficient C(n, k) modulo Mod, evaluated as
///   [(k+1) * (k+2) * ... * n] * inv[1] * inv[2] * ... * inv[n-k].
/// Requires inv[1..n-k] to be filled already. Costs O(n - k) multiplications per call.
/// Edge cases used by main (k == -1): comb(-1, -1) == 1 because both products are empty,
/// and comb(m, -1) == 0 for m >= 0 because the numerator contains the factor 0.
int comb(int n, int k)
{
    long long value = 1;
    // Numerator: n * (n-1) * ... * (k+1).
    int factor = n;
    while (factor > k)
    {
        value = value * factor % Mod;
        --factor;
    }
    // Denominator: multiply by the inverses of (n-k), ..., 2, 1.
    for (int d = n - k; d > 0; --d)
    {
        value = value * inv[d] % Mod;
    }
    return static_cast<int>(value);
}

/// Checks the dependency rule on the five low bits of mask. Each even-indexed bit
/// (0, 2, 4) may be set only if its odd-indexed neighbours inside 0..4 are set:
/// bit 0 needs bit 1, bit 2 needs bits 1 and 3, and bit 4 needs bit 3.
/// Bits above 4 are ignored.
/// NOTE(review): presumably bits 1 and 3 are the low turning points of the "W" shape,
/// which must be assigned before their neighbours — inferred from the problem name.
bool valid(int mask)
{
    for (int bit = 0; bit <= 4; bit += 2)
    {
        const bool is_set = ((mask >> bit) & 1) != 0;
        if (!is_set)
        {
            continue;
        }
        const bool left_ok  = bit == 0 || ((mask >> (bit - 1)) & 1) != 0;
        const bool right_ok = bit == 4 || ((mask >> (bit + 1)) & 1) != 0;
        if (!left_ok || !right_ok)
        {
            return false;
        }
    }
    return true;
}

/// Returns false iff two adjacent bits among bits 0..4 are both newly set, i.e. set in
/// mask but clear in last_mask (pairs (0,1), (1,2), (2,3), (3,4)); true otherwise.
bool valid_last(int last_mask, int mask)
{
    const int added = mask & ~last_mask;
    // Bit j of (added & (added >> 1)) is set iff bits j and j+1 were both added;
    // only the pairs starting at j = 0..3 matter, hence the 0b1111 filter.
    const int adjacent_pairs = added & (added >> 1) & 15;
    return adjacent_pairs == 0;
}

/// Counts the "bins" main uses in a stars-and-bars count for the current value group.
/// Result = (number of pairs {anchor, other} in {(1,0), (1,2), (3,2), (3,4)} whose anchor bit
///           is set in both masks and whose other bit is set in neither mask)
///        + (number of bits 0..4 set in mask but clear in last_mask).
int get_opt(int last_mask, int mask)
{
    const int kept  = last_mask & mask;     // bits set before and after this step
    const int empty = ~(last_mask | mask);  // bits set in neither mask
    const int fresh = mask & ~last_mask;    // bits newly set in this step

    // {anchor, other}: the pair contributes when anchor is kept and other is empty.
    static const int pairs[4][2] = {{1, 0}, {1, 2}, {3, 2}, {3, 4}};

    int total = 0;
    for (const auto& p : pairs)
    {
        if (((kept >> p[0]) & 1) != 0 && ((empty >> p[1]) & 1) != 0)
        {
            ++total;
        }
    }
    for (int bit = 0; bit < 5; ++bit)
    {
        total += (fresh >> bit) & 1;
    }
    return total;
}

/// Reads n followed by n integers. It groups equal values and runs a DP over 5-bit masks,
/// processing the value groups in increasing order. The masks it allows are those accepted
/// by valid() and valid_last(). It prints dp[31] modulo Mod.
int main()
{
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
#ifdef home
    freopen("nr.in","r",stdin);
    freopen("nr.out","w",stdout);
#endif // home
    cin>>n;
    for(int i=1; i<=n; i++)
    {
        cin>>v[i];
        inv[i] = invmod(i);  // comb() divides by 1..(group size), and every group size is <= n
    }
    sort(v+1,v+n+1);

    // Compress the sorted values into groups c[1..nr] of (value, multiplicity).
    // The first element always opens a group. Comparing it with the unused sentinel
    // v[0] (== 0) would otherwise put a leading value 0 into the unused slot c[0].
    int nr = 0;
    for(int i=1; i<=n; i++)
    {
        if(i == 1 || v[i]!=v[i-1])
        {
            c[++nr].first = v[i];
        }
        ++c[nr].second;
    }

    // Precompute the admissible masks and, for each one, its admissible predecessors.
    // A predecessor must pass valid(), be a subset of mask, and pass valid_last().
    for(int mask=0; mask<(1<<5); mask++)
    {
        // This must be a call. The former `valid[mask]` subscripted a function, which is
        // ill-formed C++ (GCC accepts it only as a pointer-arithmetic extension). The
        // result was always non-null, so no mask was ever filtered out.
        if(!valid(mask))
        {
            continue;
        }
        valid_mask.push_back(mask);
        for(int last_mask=0; last_mask<(1<<5); last_mask++)
        {
            if(!valid(last_mask))
            {
                continue;
            }
            if((mask & last_mask) != last_mask)  // last_mask must be a subset of mask
            {
                continue;
            }
            if(!valid_last(last_mask, mask))
            {
                continue;
            }
            valid_last_mask[mask].push_back(last_mask);
        }
    }

    dp[0] = 1;
    for(int i=1; i<=nr; i++)
    {
        for(auto mask : valid_mask)
        {
            for(auto last_mask : valid_last_mask[mask])
            {
                // Copies of this value left after subtracting one per newly set bit.
                int nr_free = c[i].second - (__builtin_popcount(mask) - __builtin_popcount(last_mask));
                if(nr_free < 0)
                {
                    continue;
                }
                // Distribute nr_free identical copies into nropt bins (stars and bars).
                // When nropt == 0, comb() yields 1 if nr_free == 0 and 0 otherwise.
                int nropt = get_opt(last_mask, mask);
                // Both terms are < Mod, so the sum (< 2^31 - 1) cannot overflow int.
                aux[mask] += 1LL * comb(nr_free + nropt - 1, nropt - 1) * dp[last_mask] % Mod;
                aux[mask] %= Mod;
            }
        }
        // Roll the layer: aux becomes dp and is cleared for the next group.
        for(int mask=0; mask<(1<<5); mask++)
        {
            dp[mask] = aux[mask];
            aux[mask] = 0;
        }
    }
    cout<<dp[(1<<5) - 1]<<'\n';
    return 0;
}

Compilation message

w.cpp: In function 'int main()':
w.cpp:171:23: warning: pointer to a function used in arithmetic [-Wpointer-arith]
  171 |         if(!valid[mask])
      |                       ^
# 결과 실행 시간 메모리 Grader output
1 Correct 1 ms 4444 KB Output is correct
2 Correct 1 ms 4444 KB Output is correct
3 Correct 45 ms 4728 KB Output is correct
4 Correct 251 ms 4944 KB Output is correct
# 결과 실행 시간 메모리 Grader output
1 Correct 1 ms 4444 KB Output is correct
2 Correct 2 ms 4444 KB Output is correct
3 Correct 31 ms 4676 KB Output is correct
4 Correct 97 ms 4700 KB Output is correct
5 Correct 187 ms 4740 KB Output is correct
6 Correct 282 ms 4948 KB Output is correct
# 결과 실행 시간 메모리 Grader output
1 Correct 1 ms 4444 KB Output is correct
2 Correct 1 ms 4440 KB Output is correct
3 Correct 1 ms 4444 KB Output is correct
4 Correct 1 ms 4444 KB Output is correct
5 Correct 2 ms 4444 KB Output is correct
6 Correct 5 ms 4444 KB Output is correct
7 Correct 9 ms 4604 KB Output is correct
8 Correct 88 ms 4700 KB Output is correct
9 Correct 176 ms 4740 KB Output is correct
10 Correct 284 ms 5204 KB Output is correct