#include <bits/stdc++.h>
#if defined(__AVX2__)
#include <immintrin.h>
#endif
using namespace std;
/// Counts the mythical triples i < j < k -- those whose heights
/// {H[i], H[j], H[k]} equal the distances {j-i, k-j, k-i} as multisets --
/// except the one arrangement left to count_specific_simd
/// (H[j] == k-i, H[i] == k-j, H[k] == j-i with H[i] != H[k]).
///
/// Writing a = j-i and b = k-j, the arrangements (H[i], H[j], H[k]) counted here:
///   (a+b, b, a) and (a+b, a, b) : i reaches k, j lies between them;
///   (a, b, a+b) and (a, a+b, b) : i reaches j, k lies to the right of j;
///   (b, a, a+b), a != b         : k reaches back to i (and so does j).
/// When a != b the values a, b, a+b are distinct, so a triple matches exactly one
/// arrangement. When a == b the coinciding arrangements are counted only once.
///
/// Despite the historical name, the last arrangement needs a leftward scan:
/// no element of such a triple reaches another element to its right.
/// Heights that are non-positive or point outside the array never equal a
/// distance and are skipped. Runs in O(N).
static long long count_adjacent_right_only(const std::vector<int>& H) {
    const int N = static_cast<int>(H.size());
    long long ans = 0;

    // Pass 1: anchor at p whose height reaches q = p + H[p] to the right.
    for (int p = 0; p < N; ++p) {
        const int d = H[p];
        if (d <= 0 || d >= N - p) continue;  // q must be a valid index > p (no overflow)
        const int q = p + d;
        const int aq = H[q];
        if (aq <= 0) continue;  // every arrangement below needs H[q] to be a distance >= 1

        // r strictly between p and q: H[p] = q-p and {H[q], H[r]} = {r-p, q-r}.
        // In both placements H[r] must be d - aq; only r's position differs.
        if (aq < d) {
            const int rA = p + aq;        // H[q] == r - p, H[r] == q - r
            const int rB = p + (d - aq);  // H[q] == q - r, H[r] == r - p
            if (H[rA] == d - aq) ++ans;
            // rA == rB  <=>  2*aq == d: both placements are the same triple, count once.
            if (rB != rA && H[rB] == d - aq) ++ans;
        }

        // r to the right of q, with H[p] = q - p = d.
        if (aq < N - q) {
            // H[q] = t = r - q and H[r] = d + t = r - p.
            if (H[q + aq] == d + aq) ++ans;
        }
        if (aq > d) {
            // H[q] = d + t = r - p and H[r] = t = r - q.
            const int t = aq - d;
            if (t < N - q && H[q + t] == t) ++ans;
        }
    }

    // Pass 2: arrangement (b, a, a+b) -- H[k] = k-i, H[j] = j-i, H[i] = k-j.
    for (int k = 0; k < N; ++k) {
        const int s = H[k];
        if (s <= 0 || s > k) continue;   // i = k - s must be a valid index
        const int i = k - s;
        const int b = H[i];
        if (b <= 0 || b >= s) continue;  // need i < j = k - b < k
        if (2 * b == s) continue;        // a == b: same triple as (a, b, a+b) in pass 1
        const int j = k - b;
        if (H[j] == s - b) ++ans;        // H[j] must equal j - i
    }
    return ans;
}
/// Counts mythical triples i < j < k in the single arrangement where the middle
/// element holds the largest distance:
///     H[j] == k - i,   H[i] == k - j,   H[k] == j - i,   with H[i] != H[k].
/// When H[i] == H[k] this arrangement is the same as H = (j-i, k-i, k-j), which
/// is left to the other counting helper; skipping it avoids double counting.
///
/// For a pair (i, k) the conditions reduce to: j = i + H[k] lies strictly
/// between i and k, H[i] + H[k] == k - i, and H[j] == k - i. Every pair i < k
/// is tested, so the cost is O(N^2). With AVX2, eight consecutive k are tested
/// per iteration.
static long long count_specific_simd(const std::vector<int>& H) {
    const int N = static_cast<int>(H.size());

    // Scalar test of one pair (i, k) with i < k. It is written without
    // Hi + Hk, so no addition can overflow even on out-of-range heights.
    const auto pair_matches = [&H](int i, int k) -> bool {
        const int Hi = H[i];
        const int Hk = H[k];
        if (Hi == Hk) return false;                // a == b shape: counted elsewhere
        const int dist = k - i;
        if (Hk <= 0 || Hk >= dist) return false;   // need i < j = i + Hk < k
        if (Hi != dist - Hk) return false;         // H[i] + H[k] == k - i
        return H[i + Hk] == dist;                  // H[j] == k - i
    };

    long long total = 0;
#if defined(__AVX2__)
    const int* const base = H.data();
    const __m256i vEight = _mm256_set1_epi32(8);
#endif
    for (int i = 0; i < N; ++i) {
        int k = i + 1;
#if defined(__AVX2__)
        const __m256i vi = _mm256_set1_epi32(i);
        const __m256i vHi = _mm256_set1_epi32(H[i]);
        __m256i vk = _mm256_setr_epi32(k, k + 1, k + 2, k + 3, k + 4, k + 5, k + 6, k + 7);
        for (; k + 7 < N; k += 8, vk = _mm256_add_epi32(vk, vEight)) {
            // H[k..k+7] is contiguous, so an unaligned load replaces a gather.
            const __m256i vHk = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(base + k));
            const __m256i vj = _mm256_add_epi32(vi, vHk);  // j = i + H[k]
            // i < j < k (j < k also keeps j inside the array).
            const __m256i jInRange = _mm256_and_si256(_mm256_cmpgt_epi32(vk, vj),
                                                      _mm256_cmpgt_epi32(vj, vi));
            // Lanes with an invalid j gather H[0] instead; they are masked out below.
            const __m256i vjSafe = _mm256_and_si256(vj, jInRange);
            const __m256i vHj = _mm256_i32gather_epi32(base, vjSafe, 4);
            const __m256i vsum = _mm256_add_epi32(vHi, vHk);
            // H[k] != H[i]  and  j in range
            __m256i hit = _mm256_andnot_si256(_mm256_cmpeq_epi32(vHk, vHi), jInRange);
            // H[i] + H[k] == k - i
            hit = _mm256_and_si256(hit, _mm256_cmpeq_epi32(vsum, _mm256_sub_epi32(vk, vi)));
            // H[j] == H[i] + H[k]
            hit = _mm256_and_si256(hit, _mm256_cmpeq_epi32(vHj, vsum));
            total += __builtin_popcount(static_cast<unsigned>(
                _mm256_movemask_ps(_mm256_castsi256_ps(hit))));
        }
#endif
        // Remaining k (all of them when AVX2 is unavailable).
        for (; k < N; ++k) {
            if (pair_matches(i, k)) ++total;
        }
    }
    return total;
}
/// Grader entry point: the number of mythical triples i < j < k, i.e. those with
/// {H[i], H[j], H[k]} == {j - i, k - j, k - i} as multisets.
/// The work is split between two helpers meant to count disjoint arrangements:
///  - count_specific_simd takes the shape where the middle element holds k - i,
///    H[i] == k - j and H[k] == j - i, with H[i] != H[k];
///  - count_adjacent_right_only is expected to cover every other shape.
long long count_triples(std::vector<int> H) {
    const long long middle_shape = count_specific_simd(H);
    const long long other_shapes = count_adjacent_right_only(H);
    return other_shapes + middle_shape;
}
#ifdef LOCAL_TEST
// Reference O(N^3) count used only by the local self-check below.
static long long brute_count_triples(const vector<int>& H) {
    const int n = static_cast<int>(H.size());
    long long cnt = 0;
    for (int i = 0; i < n; ++i)
        for (int j = i + 1; j < n; ++j)
            for (int k = j + 1; k < n; ++k) {
                array<int, 3> heights{H[i], H[j], H[k]};
                array<int, 3> dists{j - i, k - j, k - i};
                sort(heights.begin(), heights.end());
                sort(dists.begin(), dists.end());
                if (heights == dists) ++cnt;
            }
    return cnt;
}

// Local harness: prints the sample answer (expected 3), then cross-checks
// count_triples against the brute force on random small arrays whose
// heights lie in [1, N-1].
int main() {
    const vector<int> sample{4, 1, 4, 3, 2, 6, 1};
    cout << count_triples(sample) << "\n";  // expected: 3

    mt19937 rng(12345);
    for (int iter = 0; iter < 2000; ++iter) {
        const int n = 3 + static_cast<int>(rng() % 30);
        vector<int> H(n);
        for (int& h : H) h = 1 + static_cast<int>(rng() % (n - 1));
        const long long got = count_triples(H);
        const long long want = brute_count_triples(H);
        if (got != want) {
            cout << "mismatch on N=" << n << ": got " << got << ", want " << want << "\n";
            return 1;
        }
    }
    cout << "random cross-check passed\n";
    return 0;
}
#endif
/// Construction part of the task: not implemented.
/// This stub ignores M and K and always returns an empty array.
std::vector<int> construct_range([[maybe_unused]] int M, [[maybe_unused]] int K) {
    std::vector<int> result;
    return result;
}
/*
 * The lines below were pasted from the online judge's submission page:
 * 13 result tables, each still showing "Fetching results...". They are not
 * C++ and would stop the file from compiling, so they are kept in this comment.
 *
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
*/