This submission was migrated from the previous version of oj.uz, which used a different machine for grading. This submission may get a different result if resubmitted.
#include <bits/stdc++.h>
#define all(x) (x).begin(), (x).end()
#define ff first
#define ss second
#define pb push_back
#define mp make_pair
using namespace std;
using ll = long long;
using ull = unsigned long long;
using ld = long double;
using pll = pair<ll, ll>;
using pull = pair<ull, ull>;
using pii = pair<int, int>;
using pld = pair<ld, ld>;

// Modulus applied to every count; the printed answer is reduced by it.
ll mod = 998244353;
// Running answer, accumulated by solve() over all word lengths and printed by main().
ll ans = 0;
// a[len] is the bucket of input words of length `len` (indices 0..10).
vector<string> a[11];
// Alphabet size: 26 lowercase + 26 uppercase letters + 10 digits.
const ll n = 62;
// cnt[x][y]: number of distinct words (of the length solve() is working on,
// reversals included) that start with letter x and end with letter y.
ll cnt[70][70];
// Letter index of each alphanumeric character, in [0, n).
map<char, ll> mpp;
// Scratch DP tables, 70 per dimension (slack above n = 62).
// NOTE(review): with 8-byte long long, dp2 alone is 70^4 * 8 bytes (~192 MB).
ll dp1[70][70][70];
ll dp2[70][70][70][70];
ll dp3[70][70][70];
// Adds to the global `ans` (mod `mod`) the number of cubes whose 12 edges are
// labelled with distinct-or-repeated words of length `cur`, each word read
// from one endpoint corner to the other.
//
// Counting scheme: the 8 corners split into two groups of 4 such that every
// edge joins the groups and every corner of one group is adjacent to exactly
// 3 corners of the other. Fix the letters i, j, k, y on one group. A corner of
// the other group adjacent to letters p, q, r then contributes
//     f(p, q, r) = sum_m cnt[m][p] * cnt[m][q] * cnt[m][r]
// (choose its letter m and three words m->p, m->q, m->r), so the count is
//     sum_{i,j,k,y} f(i,j,k) * f(j,k,y) * f(i,k,y) * f(i,j,y).
//
// Side effects: appends the reversal of every word to a[cur] and deduplicates
// it, overwrites cnt (and dp1 when a[cur] is non-empty), and accumulates into ans.
void solve(int cur){
    // Reset the endpoint-pair counts left over from the previous length.
    for (int i = 0; i < n; ++i)
        for (int j = 0; j < n; ++j)
            cnt[i][j] = 0;

    std::vector<std::string>& words = a[cur];
    // No words of this length means no cube: the contribution is 0, so skip
    // the O(n^4) work below.
    if (words.empty())
        return;

    // A word can be laid along an edge in either direction, so add every
    // reversal; sort + unique then drops duplicates (palindromes, or a word
    // whose reversal is also in the input).
    const std::size_t input_count = words.size();
    for (std::size_t w = 0; w < input_count; ++w) {
        std::string reversed(words[w].rbegin(), words[w].rend());
        words.push_back(std::move(reversed));
    }
    std::sort(words.begin(), words.end());
    words.erase(std::unique(words.begin(), words.end()), words.end());

    for (const std::string& word : words)
        cnt[mpp[word.front()]][mpp[word.back()]]++;

    // dp1[i][j][k] = f(i, j, k), reduced mod `mod`.
    for (int i = 0; i < n; ++i)
        for (int j = 0; j < n; ++j)
            for (int k = 0; k < n; ++k) {
                long long sum = 0;
                for (int m = 0; m < n; ++m) {
                    sum += cnt[m][i] * cnt[m][j] % mod * cnt[m][k] % mod;
                    if (sum >= mod)
                        sum -= mod;
                }
                dp1[i][j][k] = sum;
            }

    // Sum of f(i,j,k) f(j,k,y) f(i,k,y) f(i,j,y). Computing the product inline
    // avoids materialising a 4-D table, and f(i,j,y) is read from dp1 instead of
    // being recomputed. The innermost loop runs over y so all three dp1 reads
    // are contiguous. Every factor is < mod, so each product fits in 64 bits.
    long long total = 0;
    for (int i = 0; i < n; ++i)
        for (int j = 0; j < n; ++j)
            for (int k = 0; k < n; ++k) {
                const long long ijk = dp1[i][j][k];
                if (ijk == 0)
                    continue;  // every term with this (i, j, k) is zero
                for (int y = 0; y < n; ++y) {
                    total += ijk * dp1[j][k][y] % mod * dp1[i][k][y] % mod
                             * dp1[i][j][y] % mod;
                    if (total >= mod)
                        total -= mod;
                }
            }
    ans = (ans + total) % mod;
}
int main() {
ios_base::sync_with_stdio(false);
cin.tie(nullptr);
//freopen("in.txt", "r", stdin);
//freopen("out.txt", "w", stdout);
ll t;
cin >> t;
while(t--){
string s;
cin >> s;
a[s.size()].pb(s);
}
ll cnt = 0;
for(char i = 'a'; i <= 'z'; ++i)
mpp[i] = cnt++;
for(char i = 'A'; i <= 'Z'; ++i)
mpp[i] = cnt++;
for(char i = '0'; i <= '9'; ++i)
mpp[i] = cnt++;
for(ll i = 3; i <= 10; ++i)
solve(i);
cout << ans << '\n';
}
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |