제출 #1258988

# | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리
1258988 | Sorting | 3개의 봉우리 (IOI25_triples) | C++20
0 / 100
2095 ms1864 KiB
#include <bits/stdc++.h>
#if defined(__AVX2__)
  #include <immintrin.h>
#endif
using namespace std;

// Counts index triples that are discoverable by following the "right pointer"
// q = p + H[p] from the smallest index p of the triple.
//
// For a triple i < j < k with distances a = j - i, b = k - j, c = k - i, the
// heights must be a permutation of (a, b, c). This pass covers the four
// orderings in which H[p] is a distance measured from p to another member:
//   * H[i] = c (q is k, r is j in between):  H[k] = a or H[k] = b
//   * H[i] = a (q is j, r is k to the right): H[j] = b or H[j] = c
// A triple with a == b satisfies two orderings at once and is counted once.
//
// @param H  heights; entries < 1 can never equal a distance and are skipped.
// @return   number of triples found by the four orderings above.
static long long count_adjacent_right_only(const std::vector<int>& H) {
    const long long n = static_cast<long long>(H.size());
    // 64-bit accessor so that sums of heights/indices cannot overflow int.
    const auto at = [&H](long long idx) {
        return static_cast<long long>(H[static_cast<std::size_t>(idx)]);
    };

    long long ans = 0;
    for (long long p = 0; p < n; ++p) {
        const long long d = at(p);
        if (d < 1) continue;              // not a valid distance; would also make q <= p
        const long long q = p + d;
        if (q >= n) continue;             // right pointer leaves the array

        const long long aq = at(q);
        if (aq < 1) continue;             // not a valid distance; avoids r == p / negative r

        if (aq < d) {
            // r lies strictly between p and q, H[p] = d is the outer distance.
            // H[q] is one inner distance, so H[r] must be the other: d - aq.
            const long long other = d - aq;
            const long long r_near_p = p + aq;   // r - p == H[q]
            const long long r_near_q = q - aq;   // q - r == H[q]
            if (at(r_near_p) == other) ++ans;
            // When 2 * aq == d both candidates are the same index: count once.
            if (r_near_q != r_near_p && at(r_near_q) == other) ++ans;
        }

        // r to the right of q with H[q] = r - q and H[r] = r - p = d + aq.
        const long long r_sum = q + aq;
        if (r_sum < n && at(r_sum) == d + aq) ++ans;

        // r to the right of q with H[q] = r - p = d + t and H[r] = r - q = t.
        if (aq > d) {
            const long long t = aq - d;
            const long long r_diff = q + t;
            if (r_diff < n && at(r_diff) == t) ++ans;
        }
    }
    return ans;
}

// Counts pairs (i, k) with i < k for which, taking j = i + H[k]:
//   * H[i] != H[k],
//   * j < k and j < N,
//   * H[i] + H[k] == k - i,
//   * H[j] == H[i] + H[k].
// Every pair (i, k) is examined, so the work is quadratic in N.
// When __AVX2__ is defined, eight consecutive k values are tested per
// iteration with 256-bit integer intrinsics; otherwise a plain double loop
// performs the same checks.
static long long count_specific_simd(const vector<int>& H) {
    const int N = (int)H.size();
    long long total = 0;

#if defined(__AVX2__)
    const int* base = H.data();
    for (int i = 0; i < N; ++i) {
        int Hi = H[i];
        // Broadcast the per-i constants to all 8 lanes.
        __m256i vi   = _mm256_set1_epi32(i);
        __m256i vHi  = _mm256_set1_epi32(Hi);
        __m256i vN   = _mm256_set1_epi32(N);
        __m256i vInc = _mm256_set1_epi32(8);

        int k = i + 1;

        // scalar head: advance k until it is a multiple of 8 (index, not address)
        for (; (k < N) && (k & 7); ++k) {
            int Hk = H[k];
            if (Hk == Hi) continue;                // skip x==y case (adjacent pass handles)
            int j = i + Hk;
            if (j >= k || j >= N) continue;
            int sum = Hi + Hk;
            if (sum != k - i) continue;
            if (H[j] == sum) ++total;
        }

        if (k >= N) continue;
        // Lane l holds index k + l.
        __m256i vk = _mm256_setr_epi32(k+0, k+1, k+2, k+3, k+4, k+5, k+6, k+7);

        // Vector body: runs only while all 8 lanes k..k+7 are inside the array.
        for (; k + 7 < N; k += 8, vk = _mm256_add_epi32(vk, vInc)) {
            // Gather H[k..k+7] (scale 4 = sizeof(int)).
            __m256i vHk = _mm256_i32gather_epi32(base, vk, 4);
            // mask Hk != Hi (compare-equal, then invert all bits)
            __m256i maskNeq = _mm256_xor_si256(_mm256_cmpeq_epi32(vHk, vHi), _mm256_set1_epi32(-1));

            __m256i vsum   = _mm256_add_epi32(vHi, vHk);     // Hi + Hk
            __m256i vdelta = _mm256_sub_epi32(vk, vi);       // k - i
            __m256i maskEq = _mm256_cmpeq_epi32(vsum, vdelta);

            __m256i vj = _mm256_add_epi32(vi, vHk);          // j = i + Hk
            __m256i maskJltK = _mm256_cmpgt_epi32(vk, vj);   // j < k
            // Signed compare N > j; j >= 0 presumably holds because heights are
            // non-negative -- TODO confirm against the input constraints.
            __m256i maskJin  = _mm256_cmpgt_epi32(vN, vj);   // j in [0..N-1]
            // Lanes failing the j < N test gather from index 0 instead; their
            // result is discarded by maskJin in maskAll below.
            __m256i vj_safe  = _mm256_blendv_epi8(_mm256_setzero_si256(), vj, maskJin);

            __m256i vHj = _mm256_i32gather_epi32(base, vj_safe, 4);
            __m256i maskHj = _mm256_cmpeq_epi32(vHj, vsum);

            // A lane counts only if every condition holds.
            __m256i maskAll = _mm256_and_si256(maskNeq,
                                 _mm256_and_si256(maskEq,
                                   _mm256_and_si256(maskJltK,
                                     _mm256_and_si256(maskJin, maskHj))));

            // One bit per 32-bit lane (its top bit); popcount = number of hits.
            int m = _mm256_movemask_ps(_mm256_castsi256_ps(maskAll));
            total += __builtin_popcount((unsigned)m);
        }

        // scalar tail: the fewer-than-8 remaining k values
        for (; k < N; ++k) {
            int Hk = H[k];
            if (Hk == Hi) continue;                // skip x==y case
            int j = i + Hk;
            if (j >= k || j >= N) continue;
            int sum = Hi + Hk;
            if (sum != k - i) continue;
            if (H[j] == sum) ++total;
        }
    }
#else
    // scalar fallback: same four checks as above for every pair i < k
    for (int i = 0; i < N; ++i) {
        for (int k = i + 1; k < N; ++k) {
            int Hi = H[i], Hk = H[k];
            if (Hi == Hk) continue;                 // skip x==y case
            int j = i + Hk;
            if (j >= k || j >= N) continue;
            int sum = Hi + Hk;
            if (sum != k - i) continue;
            if (H[j] == sum) ++total;
        }
    }
#endif
    return total;
}

// Counts triples i < j < k whose largest height is on the rightmost index
// while the two smaller heights are swapped relative to their distances:
//   H[k] == k - i,  H[i] == k - j,  H[j] == j - i,  with (j - i) != (k - j).
// Neither other pass can find this ordering: i + H[i] and j + H[j] do not
// land on a member of the triple, so it is anchored at k via i = k - H[k].
// The equal-gap case (j - i == k - j) is skipped because there the ordering
// coincides with H[i] = j - i, H[j] = k - j, which count_adjacent_right_only
// already counts.
static long long count_swapped_gaps_sum_right(const std::vector<int>& H) {
    const long long n = static_cast<long long>(H.size());
    const auto at = [&H](long long idx) {
        return static_cast<long long>(H[static_cast<std::size_t>(idx)]);
    };

    long long found = 0;
    for (long long k = 0; k < n; ++k) {
        const long long span = at(k);            // must equal k - i
        const long long i = k - span;
        if (span < 2 || i < 0) continue;         // need room for j strictly inside
        const long long right_gap = at(i);       // must equal k - j
        if (right_gap < 1 || right_gap >= span) continue;
        const long long left_gap = span - right_gap;   // j - i
        if (left_gap == right_gap) continue;     // counted by the adjacent pass
        const long long j = k - right_gap;
        if (at(j) == left_gap) ++found;
    }
    return found;
}

// Returns the number of index triples i < j < k whose heights
// (H[i], H[j], H[k]) are a permutation of the distances (j-i, k-j, k-i).
// The six possible orderings are split over three disjoint passes:
//   * count_adjacent_right_only    - orderings reachable from the smallest
//                                    index through p + H[p],
//   * count_specific_simd          - H[j] is the largest, H[i] = k-j, H[k] = j-i
//                                    (unequal gaps only),
//   * count_swapped_gaps_sum_right - H[k] is the largest, H[i] = k-j, H[j] = j-i
//                                    (unequal gaps only).
long long count_triples(std::vector<int> H) {
    long long ans = 0;
    ans += count_adjacent_right_only(H);
    ans += count_specific_simd(H);
    ans += count_swapped_gaps_sum_right(H);
    return ans;
}

#ifdef LOCAL_TEST
// Local smoke test (only built with -DLOCAL_TEST): prints the triple count
// for one small fixed height array.
int main() {
    const std::vector<int> sample_heights{4, 1, 4, 3, 2, 6, 1};
    const long long triple_count = count_triples(sample_heights);
    std::cout << triple_count << "\n"; // 3
}
#endif
// Placeholder for the construction sub-task: both arguments are ignored and
// an empty vector is always returned.
std::vector<int> construct_range(int M, int K) {
    static_cast<void>(M);
    static_cast<void>(K);
    return std::vector<int>();
}
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...