From 641f89a50b6c252c722a3d6ecca1acd8affb2f71 Mon Sep 17 00:00:00 2001 From: Martin Larralde <martin.larralde@embl.de> Date: Fri, 21 Jun 2024 16:24:31 +0200 Subject: [PATCH] Fix SSE2 and AVX2 implementations of `argmax` not returning the last index --- lightmotif/src/pli/platform/avx2.rs | 12 +++---- lightmotif/src/pli/platform/sse2.rs | 54 +++++++++++++++-------------- 2 files changed, 34 insertions(+), 32 deletions(-) diff --git a/lightmotif/src/pli/platform/avx2.rs b/lightmotif/src/pli/platform/avx2.rs index 0f5efe2..02bc8d0 100644 --- a/lightmotif/src/pli/platform/avx2.rs +++ b/lightmotif/src/pli/platform/avx2.rs @@ -357,10 +357,10 @@ unsafe fn argmax_f32_avx2( let r3 = _mm256_load_ps(dataptr.add(0x10)); let r4 = _mm256_load_ps(dataptr.add(0x18)); // compare scores to local maximums - let c1 = _mm256_cmp_ps(s1, r1, _CMP_LT_OS); - let c2 = _mm256_cmp_ps(s2, r2, _CMP_LT_OS); - let c3 = _mm256_cmp_ps(s3, r3, _CMP_LT_OS); - let c4 = _mm256_cmp_ps(s4, r4, _CMP_LT_OS); + let c1 = _mm256_cmp_ps(s1, r1, _CMP_LE_OS); + let c2 = _mm256_cmp_ps(s2, r2, _CMP_LE_OS); + let c3 = _mm256_cmp_ps(s3, r3, _CMP_LE_OS); + let c4 = _mm256_cmp_ps(s4, r4, _CMP_LE_OS); // replace indices of new local maximums p1 = _mm256_blendv_ps(p1, index, c1); p2 = _mm256_blendv_ps(p2, index, c2); @@ -442,8 +442,8 @@ unsafe fn argmax_u8_avx2( let r1 = _mm256_unpacklo_epi8(r, _mm256_setzero_si256()); let r2 = _mm256_unpackhi_epi8(r, _mm256_setzero_si256()); // compare scores to local maximums - let c1 = _mm256_cmpgt_epi16(s1, r1); - let c2 = _mm256_cmpgt_epi16(s2, r2); + let c1 = _mm256_or_si256(_mm256_cmpeq_epi16(r1, s1), _mm256_cmpgt_epi16(r1, s1)); + let c2 = _mm256_or_si256(_mm256_cmpeq_epi16(r2, s2), _mm256_cmpgt_epi16(r2, s2)); // replace indices of new local maximums p1 = _mm256_blendv_epi8(p1, index, c1); p2 = _mm256_blendv_epi8(p2, index, c2); diff --git a/lightmotif/src/pli/platform/sse2.rs b/lightmotif/src/pli/platform/sse2.rs index 212da5b..0c614ec 100644 --- a/lightmotif/src/pli/platform/sse2.rs +++ b/lightmotif/src/pli/platform/sse2.rs @@ -25,6 +25,8 @@ use crate::pwm::ScoringMatrix; use crate::scores::StripedScores; use crate::seq::StripedSequence; +use generic_array::ArrayLength; + /// A marker type for the SSE2 implementation of the pipeline. #[derive(Clone, Debug, Default)] pub struct Sse2; @@ -168,9 +170,11 @@ unsafe fn score_sse2<A: Alphabet, C: MultipleOf<<Sse2 as Backend>::LANES>>( #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] #[target_feature(enable = "sse2")] -unsafe fn argmax_sse2<C: MultipleOf<<Sse2 as Backend>::LANES>>( +unsafe fn argmax_sse2<C: MultipleOf<<Sse2 as Backend>::LANES> + ArrayLength>( scores: &StripedScores<f32, C>, ) -> Option<MatrixCoordinates> { + use generic_array::{ArrayLength, GenericArray}; + if scores.max_index() > u32::MAX as usize { panic!( "This implementation only supports sequences with at most {} positions, found a sequence with {} positions. Contact the developers at https://github.com/althonos/lightmotif.", @@ -181,13 +185,14 @@ unsafe fn argmax_sse2<C: MultipleOf<<Sse2 as Backend>::LANES>>( } else { let data = scores.matrix(); unsafe { - let mut output = [0u32; 16]; + let mut output = GenericArray::<u32, C>::default(); let mut best_col = 0; let mut best_row = 0; let mut best_score = -f32::INFINITY; for offset in (0..C::Quotient::USIZE).map(|i| i * <Sse2 as Backend>::LANES::USIZE) { let mut dataptr = data[0].as_ptr().add(offset); + let mut outptr = output.as_mut_ptr().add(offset); // the row index for the best score in each column // (these are 32-bit integers but for use with `_mm256_blendv_ps` // they get stored in 32-bit float vectors). @@ -196,10 +201,10 @@ unsafe fn argmax_sse2<C: MultipleOf<<Sse2 as Backend>::LANES>>( let mut p3 = _mm_setzero_ps(); let mut p4 = _mm_setzero_ps(); // store the best scores for each column - let mut s1 = _mm_load_ps(dataptr.add(0x00)); - let mut s2 = _mm_load_ps(dataptr.add(0x04)); - let mut s3 = _mm_load_ps(dataptr.add(0x08)); - let mut s4 = _mm_load_ps(dataptr.add(0x0c)); + let mut s1 = _mm_set1_ps(best_score); + let mut s2 = _mm_set1_ps(best_score); + let mut s3 = _mm_set1_ps(best_score); + let mut s4 = _mm_set1_ps(best_score); // process all rows iteratively for i in 0..data.rows() { // record the current row index @@ -210,10 +215,10 @@ unsafe fn argmax_sse2<C: MultipleOf<<Sse2 as Backend>::LANES>>( let r3 = _mm_load_ps(dataptr.add(0x08)); let r4 = _mm_load_ps(dataptr.add(0x0c)); // compare scores to local maxima - let c1 = _mm_cmplt_ps(s1, r1); - let c2 = _mm_cmplt_ps(s2, r2); - let c3 = _mm_cmplt_ps(s3, r3); - let c4 = _mm_cmplt_ps(s4, r4); + let c1 = _mm_cmple_ps(s1, r1); + let c2 = _mm_cmple_ps(s2, r2); + let c3 = _mm_cmple_ps(s3, r3); + let c4 = _mm_cmple_ps(s4, r4); // NOTE: code below could use `_mm_blendv_ps` instead, // but this instruction is only available on SSE4.1 // while the rest of the code is actually using SSE2 @@ -232,21 +237,18 @@ unsafe fn argmax_sse2<C: MultipleOf<<Sse2 as Backend>::LANES>>( dataptr = dataptr.add(data.stride()); } // find the global maximum across all columns - _mm_storeu_si128(output[0x00..].as_mut_ptr() as *mut _, _mm_castps_si128(p1)); - _mm_storeu_si128(output[0x04..].as_mut_ptr() as *mut _, _mm_castps_si128(p2)); - _mm_storeu_si128(output[0x08..].as_mut_ptr() as *mut _, _mm_castps_si128(p3)); - _mm_storeu_si128(output[0x0c..].as_mut_ptr() as *mut _, _mm_castps_si128(p4)); - for k in 0..U16::USIZE { - let row = output[k] as usize; - let col = k + offset; - let score = data[row][col]; - if score > best_score - || (score == best_score && (row, col) < (best_row, best_col)) - { - best_score = data[row][col]; - best_row = row; - best_col = col; - } + _mm_storeu_si128(outptr.add(0x00) as *mut _, _mm_castps_si128(p1)); + _mm_storeu_si128(outptr.add(0x04) as *mut _, _mm_castps_si128(p2)); + _mm_storeu_si128(outptr.add(0x08) as *mut _, _mm_castps_si128(p3)); + _mm_storeu_si128(outptr.add(0x0c) as *mut _, _mm_castps_si128(p4)); + } + for col in 0..C::USIZE { + let row = output[col] as usize; + let score = data[row][col]; + if score >= best_score { + best_score = score; + best_row = row; + best_col = col; } } Some(MatrixCoordinates::new(best_row, best_col)) @@ -305,7 +307,7 @@ impl Sse2 { } #[allow(unused)] - pub fn argmax<C: MultipleOf<<Sse2 as Backend>::LANES>>( + pub fn argmax<C: MultipleOf<<Sse2 as Backend>::LANES> + ArrayLength>( scores: &StripedScores<f32, C>, ) -> Option<MatrixCoordinates> { #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] -- GitLab