From 641f89a50b6c252c722a3d6ecca1acd8affb2f71 Mon Sep 17 00:00:00 2001
From: Martin Larralde <martin.larralde@embl.de>
Date: Fri, 21 Jun 2024 16:24:31 +0200
Subject: [PATCH] Fix SSE2 and AVX2 implementations of `argmax` not returning
 the last index

---
 lightmotif/src/pli/platform/avx2.rs | 12 +++----
 lightmotif/src/pli/platform/sse2.rs | 54 +++++++++++++++--------------
 2 files changed, 34 insertions(+), 32 deletions(-)

diff --git a/lightmotif/src/pli/platform/avx2.rs b/lightmotif/src/pli/platform/avx2.rs
index 0f5efe2..02bc8d0 100644
--- a/lightmotif/src/pli/platform/avx2.rs
+++ b/lightmotif/src/pli/platform/avx2.rs
@@ -357,10 +357,10 @@ unsafe fn argmax_f32_avx2(
                 let r3 = _mm256_load_ps(dataptr.add(0x10));
                 let r4 = _mm256_load_ps(dataptr.add(0x18));
                 // compare scores to local maximums
-                let c1 = _mm256_cmp_ps(s1, r1, _CMP_LT_OS);
-                let c2 = _mm256_cmp_ps(s2, r2, _CMP_LT_OS);
-                let c3 = _mm256_cmp_ps(s3, r3, _CMP_LT_OS);
-                let c4 = _mm256_cmp_ps(s4, r4, _CMP_LT_OS);
+                let c1 = _mm256_cmp_ps(s1, r1, _CMP_LE_OS);
+                let c2 = _mm256_cmp_ps(s2, r2, _CMP_LE_OS);
+                let c3 = _mm256_cmp_ps(s3, r3, _CMP_LE_OS);
+                let c4 = _mm256_cmp_ps(s4, r4, _CMP_LE_OS);
                 // replace indices of new local maximums
                 p1 = _mm256_blendv_ps(p1, index, c1);
                 p2 = _mm256_blendv_ps(p2, index, c2);
@@ -442,8 +442,8 @@ unsafe fn argmax_u8_avx2(
                 let r1 = _mm256_unpacklo_epi8(r, _mm256_setzero_si256());
                 let r2 = _mm256_unpackhi_epi8(r, _mm256_setzero_si256());
                 // compare scores to local maximums
-                let c1 = _mm256_cmpgt_epi16(s1, r1);
-                let c2 = _mm256_cmpgt_epi16(s2, r2);
+                let c1 = _mm256_or_si256(_mm256_cmpeq_epi16(r1, s1), _mm256_cmpgt_epi16(r1, s1));
+                let c2 = _mm256_or_si256(_mm256_cmpeq_epi16(r2, s2), _mm256_cmpgt_epi16(r2, s2));
                 // replace indices of new local maximums
                 p1 = _mm256_blendv_epi8(p1, index, c1);
                 p2 = _mm256_blendv_epi8(p2, index, c2);
diff --git a/lightmotif/src/pli/platform/sse2.rs b/lightmotif/src/pli/platform/sse2.rs
index 212da5b..0c614ec 100644
--- a/lightmotif/src/pli/platform/sse2.rs
+++ b/lightmotif/src/pli/platform/sse2.rs
@@ -25,6 +25,8 @@ use crate::pwm::ScoringMatrix;
 use crate::scores::StripedScores;
 use crate::seq::StripedSequence;
 
+use generic_array::ArrayLength;
+
 /// A marker type for the SSE2 implementation of the pipeline.
 #[derive(Clone, Debug, Default)]
 pub struct Sse2;
@@ -168,9 +170,11 @@ unsafe fn score_sse2<A: Alphabet, C: MultipleOf<<Sse2 as Backend>::LANES>>(
 
 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
 #[target_feature(enable = "sse2")]
-unsafe fn argmax_sse2<C: MultipleOf<<Sse2 as Backend>::LANES>>(
+unsafe fn argmax_sse2<C: MultipleOf<<Sse2 as Backend>::LANES> + ArrayLength>(
     scores: &StripedScores<f32, C>,
 ) -> Option<MatrixCoordinates> {
+    use generic_array::{ArrayLength, GenericArray};
+
     if scores.max_index() > u32::MAX as usize {
         panic!(
             "This implementation only supports sequences with at most {} positions, found a sequence with {} positions. Contact the developers at https://github.com/althonos/lightmotif.",
@@ -181,13 +185,14 @@ unsafe fn argmax_sse2<C: MultipleOf<<Sse2 as Backend>::LANES>>(
     } else {
         let data = scores.matrix();
         unsafe {
-            let mut output = [0u32; 16];
+            let mut output = GenericArray::<u32, C>::default();
             let mut best_col = 0;
             let mut best_row = 0;
             let mut best_score = -f32::INFINITY;
 
             for offset in (0..C::Quotient::USIZE).map(|i| i * <Sse2 as Backend>::LANES::USIZE) {
                 let mut dataptr = data[0].as_ptr().add(offset);
+                let mut outptr = output.as_mut_ptr().add(offset);
                 // the row index for the best score in each column
                 // (these are 32-bit integers but for use with `_mm256_blendv_ps`
                 // they get stored in 32-bit float vectors).
@@ -196,10 +201,10 @@ unsafe fn argmax_sse2<C: MultipleOf<<Sse2 as Backend>::LANES>>(
                 let mut p3 = _mm_setzero_ps();
                 let mut p4 = _mm_setzero_ps();
                 // store the best scores for each column
-                let mut s1 = _mm_load_ps(dataptr.add(0x00));
-                let mut s2 = _mm_load_ps(dataptr.add(0x04));
-                let mut s3 = _mm_load_ps(dataptr.add(0x08));
-                let mut s4 = _mm_load_ps(dataptr.add(0x0c));
+                let mut s1 = _mm_set1_ps(best_score);
+                let mut s2 = _mm_set1_ps(best_score);
+                let mut s3 = _mm_set1_ps(best_score);
+                let mut s4 = _mm_set1_ps(best_score);
                 // process all rows iteratively
                 for i in 0..data.rows() {
                     // record the current row index
@@ -210,10 +215,10 @@ unsafe fn argmax_sse2<C: MultipleOf<<Sse2 as Backend>::LANES>>(
                     let r3 = _mm_load_ps(dataptr.add(0x08));
                     let r4 = _mm_load_ps(dataptr.add(0x0c));
                     // compare scores to local maxima
-                    let c1 = _mm_cmplt_ps(s1, r1);
-                    let c2 = _mm_cmplt_ps(s2, r2);
-                    let c3 = _mm_cmplt_ps(s3, r3);
-                    let c4 = _mm_cmplt_ps(s4, r4);
+                    let c1 = _mm_cmple_ps(s1, r1);
+                    let c2 = _mm_cmple_ps(s2, r2);
+                    let c3 = _mm_cmple_ps(s3, r3);
+                    let c4 = _mm_cmple_ps(s4, r4);
                     // NOTE: code below could use `_mm_blendv_ps` instead,
                     //       but this instruction is only available on SSE4.1
                     //       while the rest of the code is actually using SSE2
@@ -232,21 +237,18 @@ unsafe fn argmax_sse2<C: MultipleOf<<Sse2 as Backend>::LANES>>(
                     dataptr = dataptr.add(data.stride());
                 }
                 // find the global maximum across all columns
-                _mm_storeu_si128(output[0x00..].as_mut_ptr() as *mut _, _mm_castps_si128(p1));
-                _mm_storeu_si128(output[0x04..].as_mut_ptr() as *mut _, _mm_castps_si128(p2));
-                _mm_storeu_si128(output[0x08..].as_mut_ptr() as *mut _, _mm_castps_si128(p3));
-                _mm_storeu_si128(output[0x0c..].as_mut_ptr() as *mut _, _mm_castps_si128(p4));
-                for k in 0..U16::USIZE {
-                    let row = output[k] as usize;
-                    let col = k + offset;
-                    let score = data[row][col];
-                    if score > best_score
-                        || (score == best_score && (row, col) < (best_row, best_col))
-                    {
-                        best_score = data[row][col];
-                        best_row = row;
-                        best_col = col;
-                    }
+                _mm_storeu_si128(outptr.add(0x00) as *mut _, _mm_castps_si128(p1));
+                _mm_storeu_si128(outptr.add(0x04) as *mut _, _mm_castps_si128(p2));
+                _mm_storeu_si128(outptr.add(0x08) as *mut _, _mm_castps_si128(p3));
+                _mm_storeu_si128(outptr.add(0x0c) as *mut _, _mm_castps_si128(p4));
+            }
+            for col in 0..C::USIZE {
+                let row = output[col] as usize;
+                let score = data[row][col];
+                if score >= best_score {
+                    best_score = score;
+                    best_row = row;
+                    best_col = col;
                 }
             }
             Some(MatrixCoordinates::new(best_row, best_col))
@@ -305,7 +307,7 @@ impl Sse2 {
     }
 
     #[allow(unused)]
-    pub fn argmax<C: MultipleOf<<Sse2 as Backend>::LANES>>(
+    pub fn argmax<C: MultipleOf<<Sse2 as Backend>::LANES> + ArrayLength>(
         scores: &StripedScores<f32, C>,
     ) -> Option<MatrixCoordinates> {
         #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
-- 
GitLab