#include <bits/stdc++.h>
using namespace std;
using ll = long long;
#define pb push_back
#define dbg(x) cerr << #x << " " << x << "\n"

/// Largest supported n; every table below is sized from it.
constexpr int N = 100;

/// Memo table for getDp, indexed [ab][bc][ca][turn][moved]:
/// turn 0/1/2 means A/B/C is to move, moved is 0/1.
/// A zero entry means "not computed yet".
int dp[2 * N + 1][2 * N + 1][2 * N + 1][3][2];

/// Occurrence counts per card value for decks A, B and C
/// (solveTest clears entries 1..3n before each test).
int freqA[3 * N + 1], freqB[3 * N + 1], freqC[3 * N + 1];

/// f[i] = i! mod MOD, filled by solveTest.
int f[3 * N + 1];

/// Modulus for all counts.
constexpr int MOD = 1000000007;
/// In-place modular addition: a becomes a + b, folded back below MOD.
/// Meant for operands already reduced into [0, MOD), where one conditional
/// subtraction is enough.
void add (int &a, int b) {
    const int sum = a + b;
    a = (sum < MOD) ? sum : sum - MOD;
}
/// Counts (mod MOD) the ways to complete the process from the given state.
///
/// @param ab    counter of values shared by A and B (initially from solveTest)
/// @param bc    counter of values shared by B and C
/// @param ca    counter of values shared by C and A
/// @param turn  player to move: 0 = A, 1 = B, 2 = C; after C the next round
///              starts again at A with `moved` reset to false
/// @param moved true once A took its `ab` option or B took its `bc` option
///              earlier in the current round; C may use its `bc` option
///              only when it is set
/// @return the count modulo MOD
///
/// Base case: when at most one counter is non-zero, the answer is
/// f[that counter] (f holds factorials mod MOD, filled by solveTest).
///
/// Memoisation stores ans + 1 in dp, so a stored 0 always means "not computed
/// yet". Storing the raw answer made every state whose answer is 0 (mod MOD)
/// look uncomputed, so it was re-evaluated on every visit.
int getDp (int ab, int bc, int ca, int turn, bool moved) {
    const int largest = max ({ab, bc, ca});
    if (ab + bc + ca == largest)
        return f[largest];

    int &memo = dp[ab][bc][ca][turn][moved];
    if (memo)
        return memo - 1;

    // Each option is weighted by the counter it draws from. No explicit
    // "pass" option is needed: if both counters the current player can draw
    // from are zero, at most one counter is non-zero overall, and the base
    // case above has already returned.
    int ans = 0;
    if (turn == 0) { /// A -> B or A -> C
        if (ab > 0)
            add (ans, 1ll * ab * getDp (ab - 1, bc, ca, 1, true) % MOD);
        if (ca > 0)
            add (ans, 1ll * ca * getDp (ab, bc + 1, ca - 1, 1, moved) % MOD);
    } else if (turn == 1) { /// B -> A or B -> C
        if (bc > 0)
            add (ans, 1ll * bc * getDp (ab, bc - 1, ca, 2, true) % MOD);
        if (ab > 0)
            add (ans, 1ll * ab * getDp (ab - 1, bc, ca + 1, 2, moved) % MOD);
    } else if (turn == 2) { /// C -> B or C -> A
        if (ca > 0)
            add (ans, 1ll * ca * getDp (ab, bc, ca - 1, 0, false) % MOD);
        if (moved && bc > 0)
            add (ans, 1ll * bc * getDp (ab + 1, bc - 1, ca, 0, false) % MOD);
    }

    memo = ans + 1;  // ans < MOD, so ans + 1 still fits in int
    return ans;
}
/// Solves one test case.
///
/// Reads the 2n cards of A, then of B, then of C (values in 1..3n). For each
/// pair of decks it counts the values present in both, and prints getDp for
/// that starting state with A to move.
///
/// The memo table dp is deliberately NOT cleared here. getDp's result depends
/// only on its arguments and on f, and f[i] is always i! mod MOD, so entries
/// computed for earlier tests remain valid. Wiping the whole table on every
/// test was pure overhead.
void solveTest (int n) {
    f[0] = 1;
    for (int i = 1; i <= 3 * n; i++)
        f[i] = 1ll * f[i - 1] * i % MOD;

    for (int i = 1; i <= 3 * n; i++)
        freqA[i] = freqB[i] = freqC[i] = 0;

    // Each deck holds 2n cards; tally how often each value occurs in it.
    auto readDeck = [n] (int *freq) {
        for (int i = 0; i < 2 * n; i++) {
            int x;
            cin >> x;
            freq[x]++;
        }
    };
    readDeck (freqA);
    readDeck (freqB);
    readDeck (freqC);

    // Number of values that appear in both decks of each pair.
    int AB = 0, BC = 0, CA = 0;
    for (int i = 1; i <= 3 * n; i++) {
        if (freqA[i] && freqB[i])
            AB++;
        if (freqB[i] && freqC[i])
            BC++;
        if (freqC[i] && freqA[i])
            CA++;
    }

    cout << getDp (AB, BC, CA, 0, false) << "\n";
}
/// Entry point: reads n (shared by every test) and the number of tests t,
/// then solves each test in turn.
int main () {
    // Only C++ streams are used, so detaching them from C stdio and untying
    // cin from cout is safe, and it speeds up reading the 6n values per test.
    ios::sync_with_stdio (false);
    cin.tie (nullptr);

    int n, t;
    cin >> n >> t;
    while (t--)
        solveTest (n);
    return 0;
}
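/**
Sample input: the first line holds n and t; then, for each test, the 2n cards
of A, of B and of C (one group per line here).
1 1
1 2
3 3
2 1
**/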
/**
Grader results (column headers translated from Korean:
결과 = Result, 실행 시간 = Run time, 메모리 = Memory):

 #  | Result  | Run time | Memory    | Grader output
 1  | Correct | 171 ms   | 190968 KB | Output is correct
 2  | Correct | 220 ms   | 191096 KB | Output is correct
 3  | Correct | 251 ms   | 191096 KB | Output is correct
 4  | Correct | 230 ms   | 190968 KB | Output is correct
 5  | Correct | 448 ms   | 191252 KB | Output is correct
 6  | Correct | 487 ms   | 191096 KB | Output is correct
 7  | Correct | 548 ms   | 190968 KB | Output is correct
 8  | Correct | 655 ms   | 191084 KB | Output is correct
 9  | Correct | 763 ms   | 191224 KB | Output is correct
 10 | Correct | 896 ms   | 191224 KB | Output is correct
**/