# | Time | Username | Problem | Language | Result | Execution time | Memory |
---|---|---|---|---|---|---|---|
1258990 | Sorting | Triple Peaks (IOI25_triples) | C++20 | 0 ms | 0 KiB |
// Counts index triples i < j < k with
//     H[i] == k - j,   H[j] == k - i,   H[k] == j - i.
// The search runs over pairs (i, k): j is forced to i + H[k], and the pair
// qualifies when i < j < k, H[i] + H[k] == k - i and H[j] == k - i.
// Heights that are zero or negative can never take part in such a triple
// (every gap between distinct indices is >= 1), so they are rejected before
// being used as an offset; this keeps every H[j] read inside the array.
// Uses AVX2 (8 values of k per step) when __AVX2__ is defined, otherwise a
// plain scalar loop. O(N^2) time in the worst case, O(1) extra space.
static long long count_specific_simd(const std::vector<int>& H) {
    const int N = static_cast<int>(H.size());
    const int* const h = H.data();
    long long total = 0;

    // Scalar check of one (i, k) pair, shared by the SIMD tail and the
    // fallback so the two paths cannot drift apart. i + H[k] is only formed
    // once it is known to lie strictly inside (i, k), so it cannot overflow
    // or index out of bounds.
    const auto matches = [h](int i, int k) -> bool {
        const int delta = k - i;                      // k - i >= 1
        const int hk = h[k];
        if (hk <= 0 || hk >= delta) return false;     // need i < j = i + H[k] < k
        if (h[i] != delta - hk) return false;         // H[i] == k - j
        return h[i + hk] == delta;                    // H[j] == k - i
    };

    for (int i = 0; i < N; ++i) {
        const int hi = h[i];
        // H[i] == k - j with j >= i + 1 forces k >= i + H[i] + 1, which must
        // still be a valid index; otherwise no k can match this i.
        if (hi <= 0 || hi >= N - 1 - i) continue;
        int k = i + hi + 1;
#if defined(__AVX2__)
        const __m256i vi = _mm256_set1_epi32(i);
        const __m256i vhi = _mm256_set1_epi32(hi);
        const __m256i vzero = _mm256_setzero_si256();
        const __m256i vstep = _mm256_set1_epi32(8);
        __m256i vk = _mm256_setr_epi32(k, k + 1, k + 2, k + 3, k + 4, k + 5, k + 6, k + 7);
        for (; k + 7 < N; k += 8, vk = _mm256_add_epi32(vk, vstep)) {
            // H[k..k+7] is contiguous: a plain unaligned load, no gather needed.
            const __m256i vhk = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(h + k));
            const __m256i vdelta = _mm256_sub_epi32(vk, vi);             // k - i
            // 0 < H[k] < k - i   <=>   i < j < k
            const __m256i in_range = _mm256_and_si256(_mm256_cmpgt_epi32(vhk, vzero),
                                                      _mm256_cmpgt_epi32(vdelta, vhk));
            // H[i] == (k - i) - H[k]; only meaningful in lanes where in_range holds.
            const __m256i hi_ok = _mm256_cmpeq_epi32(vhi, _mm256_sub_epi32(vdelta, vhk));
            const __m256i cand = _mm256_and_si256(in_range, hi_ok);
            if (_mm256_testz_si256(cand, cand)) continue;              // no candidate lane
            // Masked gather: only lanes with a valid j are loaded; others yield 0.
            const __m256i vj = _mm256_add_epi32(vi, vhk);
            const __m256i vhj = _mm256_mask_i32gather_epi32(vzero, h, vj, cand, 4);
            const __m256i hit = _mm256_and_si256(cand, _mm256_cmpeq_epi32(vhj, vdelta));
            const int lanes = _mm256_movemask_ps(_mm256_castsi256_ps(hit));
            total += __builtin_popcount(static_cast<unsigned>(lanes));
        }
#endif
        // Scalar tail (or the whole range when AVX2 is unavailable).
        for (; k < N; ++k) {
            if (matches(i, k)) ++total;
        }
    }
    return total;
}
// Stub: M and K are currently unused and the result is always an empty vector.
// NOTE(review): presumably a placeholder for an unimplemented construction
// routine — confirm the intended contract before relying on it.
std::vector<int> construct_range([[maybe_unused]] int M, [[maybe_unused]] int K) {
    std::vector<int> heights;
    return heights;
}