#include <bits/stdc++.h>
#if defined(__AVX2__)
#include <immintrin.h>
#endif
using namespace std;
// Reports whether x is a usable index into an array of N elements,
// i.e. whether 0 <= x < N.
static inline bool in_bounds(long long x, int N) {
    if (x < 0) {
        return false;
    }
    return x < static_cast<long long>(N);
}
// Counts the distinct valid triples i < j < k in which at least one member p has a
// height equal to its distance to another member q of the triple (a[p] == |p - q|).
// A triple is valid when the multiset {a[i], a[j], a[k]} equals the multiset of its
// three gaps {j - i, k - j, k - i}.
//
// Enumeration: for each p the partner has to be q = p - a[p] or q = p + a[p]. The
// height a[q] must then equal one of the two gaps that touch the third member r, so
// r is one of p - a[q], p + a[q], q - a[q], q + a[q]. That is at most eight
// candidates per p. Every candidate is checked against the full validity condition,
// and the survivors are de-duplicated because one triple can be reached from more
// than one (p, q) pair.
//
// Heights that are non-positive or too large are tolerated: they only produce
// candidates that fail the bounds check or the validity check.
//
// Returns the number of such triples. Time is O(N log N), dominated by sorting the
// recorded triples.
static long long count_adjacent(const std::vector<int>& a) {
    const long long n = static_cast<long long>(a.size());
    std::vector<std::array<int, 3>> found;  // validated triples, stored as (i, j, k)

    const auto inside = [n](long long idx) { return idx >= 0 && idx < n; };

    // Validates the triple formed by p, q and r and records it when it is valid.
    const auto consider = [&](long long p, long long q, long long r) {
        if (!inside(r) || r == p || r == q) {
            return;
        }
        std::array<long long, 3> pos{p, q, r};
        std::sort(pos.begin(), pos.end());
        std::array<long long, 3> gaps{pos[1] - pos[0], pos[2] - pos[1], pos[2] - pos[0]};
        std::array<long long, 3> heights{a[pos[0]], a[pos[1]], a[pos[2]]};
        std::sort(gaps.begin(), gaps.end());
        std::sort(heights.begin(), heights.end());
        if (gaps == heights) {
            found.push_back({static_cast<int>(pos[0]),
                             static_cast<int>(pos[1]),
                             static_cast<int>(pos[2])});
        }
    };

    for (long long p = 0; p < n; ++p) {
        const long long d = a[p];
        if (d <= 0) {
            continue;  // a non-positive height can never equal a gap
        }
        for (const long long q : {p - d, p + d}) {
            if (!inside(q)) {
                continue;
            }
            const long long e = a[q];
            if (e <= 0) {
                continue;  // same reasoning for the partner's height
            }
            for (const long long r : {p - e, p + e, q - e, q + e}) {
                consider(p, q, r);
            }
        }
    }

    // The same triple may have been recorded several times; keep one copy of each.
    std::sort(found.begin(), found.end());
    found.erase(std::unique(found.begin(), found.end()), found.end());
    return static_cast<long long>(found.size());
}
// Counts triples i < j < k that realise one specific assignment of heights to gaps:
//     H[i] = k - j,   H[j] = k - i,   H[k] = j - i
// For a pair (i, k) the middle index is forced: j = i + H[k]. The pair contributes
// when i < j < k, H[i] + H[k] == k - i and H[j] == H[i] + H[k].
// Pairs with H[i] == H[k] (equal gaps, j - i == k - j) are skipped so that this
// count does not overlap with the adjacent-gap pass.
//
// All (i, k) pairs are visited, so the work is quadratic in N. When the translation
// unit is built with AVX2, eight consecutive k values are processed per step;
// otherwise a plain scalar double loop is used. Both variants return the same count.
static long long count_specific_simd(const std::vector<int>& H) {
    const int N = static_cast<int>(H.size());
    long long total = 0;

    // Scalar test for a single (i, k) pair. 64-bit arithmetic keeps i + H[k] and
    // H[i] + H[k] from overflowing, and the i < j guard keeps the H[j] read in range
    // even for non-positive heights.
    const auto pair_matches = [&H](int i, int k) -> bool {
        const long long Hi = H[static_cast<std::size_t>(i)];
        const long long Hk = H[static_cast<std::size_t>(k)];
        if (Hi == Hk) {
            return false;
        }
        const long long j = i + Hk;
        if (j <= i || j >= k) {
            return false;
        }
        const long long sum = Hi + Hk;
        return sum == static_cast<long long>(k) - i && H[static_cast<std::size_t>(j)] == sum;
    };

#if defined(__AVX2__)
    const int* base = H.data();
    const __m256i vStep = _mm256_set1_epi32(8);
    const __m256i vLane = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
    const __m256i vOnes = _mm256_set1_epi32(-1);

    for (int i = 0; i < N; ++i) {
        const __m256i vi = _mm256_set1_epi32(i);
        const __m256i vHi = _mm256_set1_epi32(H[static_cast<std::size_t>(i)]);

        int k = i + 1;
        // Lane t holds the index k + t.
        __m256i vk = _mm256_add_epi32(_mm256_set1_epi32(k), vLane);

        for (; k + 7 < N; k += 8, vk = _mm256_add_epi32(vk, vStep)) {
            // H[k..k+7] is contiguous, so an unaligned load is enough.
            const __m256i vHk =
                _mm256_loadu_si256(reinterpret_cast<const __m256i*>(base + k));

            // Lanes where H[k] != H[i].
            const __m256i differ =
                _mm256_xor_si256(_mm256_cmpeq_epi32(vHk, vHi), vOnes);

            // Lanes where H[i] + H[k] == k - i.
            const __m256i vsum = _mm256_add_epi32(vHi, vHk);
            const __m256i sumOk = _mm256_cmpeq_epi32(vsum, _mm256_sub_epi32(vk, vi));

            // j = i + H[k]; keep only lanes with i < j < k.
            const __m256i vj = _mm256_add_epi32(vi, vHk);
            const __m256i jInside = _mm256_and_si256(_mm256_cmpgt_epi32(vj, vi),
                                                     _mm256_cmpgt_epi32(vk, vj));

            // Rejected lanes gather index 0 so the gather never leaves the array;
            // their result is discarded by the final mask.
            const __m256i vjSafe = _mm256_and_si256(vj, jInside);
            const __m256i vHj = _mm256_i32gather_epi32(base, vjSafe, 4);
            const __m256i midOk = _mm256_cmpeq_epi32(vHj, vsum);

            const __m256i hit = _mm256_and_si256(_mm256_and_si256(differ, sumOk),
                                                 _mm256_and_si256(jInside, midOk));

            // One sign bit per 32-bit lane; count the matching lanes.
            const int bits = _mm256_movemask_ps(_mm256_castsi256_ps(hit));
            total += __builtin_popcount(static_cast<unsigned>(bits));
        }

        // Fewer than eight k values remain for this i.
        for (; k < N; ++k) {
            if (pair_matches(i, k)) {
                ++total;
            }
        }
    }
#else
    // Portable scalar variant: same pairs, one at a time.
    for (int i = 0; i < N; ++i) {
        for (int k = i + 1; k < N; ++k) {
            if (pair_matches(i, k)) {
                ++total;
            }
        }
    }
#endif
    return total;
}
// Returns the total number of triples for the height list H by adding the results
// of the two counting passes defined earlier in this file:
//   - count_adjacent:      the adjacent-gap pass;
//   - count_specific_simd: the single remaining height/gap assignment, which skips
//                          pairs with H[i] == H[k] so the two passes do not overlap.
// H is not validated or modified here; it is handed to both passes unchanged.
long long count_triples(std::vector<int> H) {
    long long ans = count_adjacent(H);
    ans += count_specific_simd(H);
    return ans;
}
#ifdef LOCAL_TEST
// Local smoke run (only built when LOCAL_TEST is defined): prints the triple
// count for one fixed height list.
int main() {
    const std::vector<int> heights{4, 1, 4, 3, 2, 6, 1};
    const long long triples = count_triples(heights);
    std::cout << triples << "\n";  // expected 3
    return 0;
}
#endif
// Placeholder: both arguments are ignored and an empty vector is returned.
std::vector<int> construct_range(int M, int K) {
    (void)M;
    (void)K;
    return std::vector<int>();
}
// NOTE: the lines below are not C++. They are leftover text from a results table
// (header row plus a "Fetching results..." placeholder, repeated) and are kept
// here as comments only.
// # | Verdict | Execution time | Memory | Grader output |
// ---|
// Fetching results... |
// # | Verdict | Execution time | Memory | Grader output |
// ---|
// Fetching results... |
// # | Verdict | Execution time | Memory | Grader output |
// ---|
// Fetching results... |
// # | Verdict | Execution time | Memory | Grader output |
// ---|
// Fetching results... |
// # | Verdict | Execution time | Memory | Grader output |
// ---|
// Fetching results... |
// # | Verdict | Execution time | Memory | Grader output |
// ---|
// Fetching results... |
// # | Verdict | Execution time | Memory | Grader output |
// ---|
// Fetching results... |
// # | Verdict | Execution time | Memory | Grader output |
// ---|
// Fetching results... |
// # | Verdict | Execution time | Memory | Grader output |
// ---|
// Fetching results... |
// # | Verdict | Execution time | Memory | Grader output |
// ---|
// Fetching results... |
// # | Verdict | Execution time | Memory | Grader output |
// ---|
// Fetching results... |
// # | Verdict | Execution time | Memory | Grader output |
// ---|
// Fetching results... |
// # | Verdict | Execution time | Memory | Grader output |
// ---|
// Fetching results... |