제출 #1258986

#제출 시각아이디문제언어결과실행 시간메모리
1258986Sorting3개의 봉우리 (IOI25_triples)C++20
0 / 100
2095 ms1864 KiB
#include <bits/stdc++.h> #if defined(__AVX2__) #include <immintrin.h> #endif using namespace std; static inline bool in_bounds(long long x, int N) { return 0 <= x && x < N; } // Count all triples where some index is next to a gap whose size equals its height. // For each p, we try q = p + a[p] (p is the smaller endpoint) and q = p - a[p] (p is the larger endpoint). // For each directed pair (m < M) with d = a[endpoint_with_height], enumerate r in O(1) many positions: // - r between m and M: two possibilities // - r outside on the appropriate side: two possibilities static long long count_adjacent(const vector<int>& a) { const int N = (int)a.size(); long long ans = 0; for (int p = 0; p < N; ++p) { // Case 1: q = p + a[p] (p is the smaller endpoint, pair is (m=p, M=q)) int qR = p + a[p]; if (qR < N) { int m = p, M = qR, d = a[p]; int aq = a[qR]; // r between m and M: // If aq = t with 1 <= t < d, then r = m + t, and a[r] must be d - t. // Or aq = d - t, then r = m + (d - aq) and a[r] must be aq. if (aq < d) { int r1 = m + aq; if (a[r1] == d - aq) ++ans; int r2 = m + (d - aq); if (a[r2] == aq) ++ans; } // r to the right of M: // If aq = t, r = M + t, require a[r] = d + t. long long r3 = (long long)M + aq; if (r3 < N && a[(int)r3] == d + aq) ++ans; // Or if aq = d + t (i.e., aq > d), let t = aq - d. // Then r = M + t, require a[r] = t. if (aq > d) { int t = aq - d; int r4 = M + t; if (r4 < N && a[r4] == t) ++ans; } } // Case 2: q = p - a[p] (p is the larger endpoint, pair is (m=q, M=p)) int qL = p - a[p]; if (qL >= 0) { int m = qL, M = p, d = a[p]; int aq = a[qL]; // r between m and M: if (aq < d) { int r1 = m + aq; if (a[r1] == d - aq) ++ans; int r2 = m + (d - aq); if (a[r2] == aq) ++ans; } // r to the left of m: long long r3 = (long long)m - aq; if (r3 >= 0 && a[(int)r3] == d + aq) ++ans; if (aq > d) { int t = aq - d; int r4 = m - t; if (r4 >= 0 && a[r4] == t) ++ans; } } } return ans; } // SIMD-accelerated count for the specific permutation: // H[i] = k - j, H[j] = k - i, H[k] = j - i (i < j < k) // Equivalent check for each pair (i, k): // let L = H[k], R = H[i], j = i + L // require: R + L == k - i, j < k, H[j] == R + L // To avoid overlap with the adjacent-gap pass when j-i == k-j (i.e., H[i] == H[k]), // we skip pairs with H[i] == H[k] here. static long long count_specific_simd(const vector<int>& H) { const int N = (int)H.size(); long long total = 0; #if defined(__AVX2__) const int* base = H.data(); for (int i = 0; i < N; ++i) { int Hi = H[i]; __m256i vi = _mm256_set1_epi32(i); __m256i vHi = _mm256_set1_epi32(Hi); __m256i vN = _mm256_set1_epi32(N); __m256i vInc = _mm256_set1_epi32(8); int k = i + 1; // Scalar head to align k on 8 for (; (k < N) && (k & 7); ++k) { int Hk = H[k]; if (Hk == Hi) continue; // avoid overlap with adjacent-gap case when x == y int j = i + Hk; if (j >= k) continue; int sum = Hi + Hk; if (sum != k - i) continue; if (H[j] == sum) ++total; } if (k >= N) continue; // Initialize vector of ks: [k, k+1, ..., k+7] __m256i vk = _mm256_setr_epi32(k+0, k+1, k+2, k+3, k+4, k+5, k+6, k+7); for (; k + 7 < N; k += 8, vk = _mm256_add_epi32(vk, vInc)) { // Load H[k] via gather __m256i vHk = _mm256_i32gather_epi32(base, vk, 4); // Mask out lanes with Hk == Hi (avoid x == y overlap) __m256i maskNeq = _mm256_cmpeq_epi32(vHk, vHi); maskNeq = _mm256_xor_si256(maskNeq, _mm256_set1_epi32(-1)); // invert: 1s where Hk != Hi // sum = Hi + Hk, delta = k - i __m256i vsum = _mm256_add_epi32(vHi, vHk); __m256i vdelta = _mm256_sub_epi32(vk, vi); __m256i maskEq = _mm256_cmpeq_epi32(vsum, vdelta); // j = i + Hk __m256i vj = _mm256_add_epi32(vi, vHk); // Ensure j < k __m256i maskJltK = _mm256_cmpgt_epi32(vk, vj); // Ensure j < N (safe gather). Build mask and make a safe index where out-of-range -> 0 __m256i maskJin = _mm256_cmpgt_epi32(vN, vj); // (N > j) => j in [0..N-1] __m256i vj_safe = _mm256_blendv_epi8(_mm256_setzero_si256(), vj, maskJin); // Gather H[j] __m256i vHj = _mm256_i32gather_epi32(base, vj_safe, 4); __m256i maskHj = _mm256_cmpeq_epi32(vHj, vsum); // Combine masks __m256i maskAll = _mm256_and_si256(maskNeq, _mm256_and_si256(maskEq, _mm256_and_si256(maskJltK, maskJin))); maskAll = _mm256_and_si256(maskAll, maskHj); // Count lanes int m = _mm256_movemask_ps(_mm256_castsi256_ps(maskAll)); total += __builtin_popcount((unsigned)m); } // Scalar tail for (; k < N; ++k) { int Hk = H[k]; if (Hk == Hi) continue; // avoid overlap for x == y int j = i + Hk; if (j >= k || j >= N) continue; int sum = Hi + Hk; if (sum != k - i) continue; if (H[j] == sum) ++total; } } #else // Portable scalar fallback (still correct, just slower). for (int i = 0; i < N; ++i) { for (int k = i + 1; k < N; ++k) { int Hi = H[i], Hk = H[k]; if (Hi == Hk) continue; // avoid overlap when j-i == k-j int j = i + Hk; if (j >= k || j >= N) continue; int sum = Hi + Hk; if (sum != k - i) continue; if (H[j] == sum) ++total; } } #endif return total; } long long count_triples(std::vector<int> H) { // Part I only. // Ensure H values are valid as per problem statement (1..N-1). const int N = (int)H.size(); // Adjacent-gap enumeration (O(N)) long long ans = count_adjacent(H); // SIMD-accelerated specific-pattern enumeration (disjoint from the above by x!=y filter) ans += count_specific_simd(H); return ans; } #ifdef LOCAL_TEST // Simple sanity test int main() { vector<int> H = {4,1,4,3,2,6,1}; cout << count_triples(H) << "\n"; // expected 3 return 0; } #endif std::vector<int> construct_range(int M, int K) { return {};}
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...