From b167120d0c59c19cd15ecae1d2a9e4538ed16ead Mon Sep 17 00:00:00 2001
From: Martin Larralde <martin.larralde@embl.de>
Date: Fri, 21 Jun 2024 11:20:02 +0200
Subject: [PATCH] Add SSE2 implementation of the `Encode` pipeline operation

---
 lightmotif/src/pli/mod.rs           | 11 +++-
 lightmotif/src/pli/platform/avx2.rs | 14 +++--
 lightmotif/src/pli/platform/sse2.rs | 84 +++++++++++++++++++++++++++++
 3 files changed, 103 insertions(+), 6 deletions(-)

diff --git a/lightmotif/src/pli/mod.rs b/lightmotif/src/pli/mod.rs
index 056f933..b6e8f4e 100644
--- a/lightmotif/src/pli/mod.rs
+++ b/lightmotif/src/pli/mod.rs
@@ -323,7 +323,16 @@ impl<A: Alphabet> Pipeline<A, Sse2> {
     }
 }
 
-impl<A: Alphabet> Encode<A> for Pipeline<A, Sse2> {}
+impl<A: Alphabet> Encode<A> for Pipeline<A, Sse2> {
+    #[inline] // thin forwarding wrapper around `Sse2::encode_into`
+    fn encode_into<S: AsRef<[u8]>>(
+        &self,
+        seq: S, // anything viewable as a byte slice
+        dst: &mut [A::Symbol], // output buffer for the encoded symbols
+    ) -> Result<(), InvalidSymbol> {
+        Sse2::encode_into::<A>(seq.as_ref(), dst) // delegate to the SSE2 backend entry point
+    }
+}
 
 impl<A, C> Score<f32, A, C> for Pipeline<A, Sse2>
 where
diff --git a/lightmotif/src/pli/platform/avx2.rs b/lightmotif/src/pli/platform/avx2.rs
index 1e731f0..0f5efe2 100644
--- a/lightmotif/src/pli/platform/avx2.rs
+++ b/lightmotif/src/pli/platform/avx2.rs
@@ -41,6 +41,8 @@ unsafe fn encode_into_avx2<A>(seq: &[u8], dst: &mut [A::Symbol]) -> Result<(), I
 where
     A: Alphabet,
 {
+    const STRIDE: usize = std::mem::size_of::<__m256i>();
+
     let alphabet = A::as_str().as_bytes();
     let g = Pipeline::<A, _>::generic();
     let l = seq.len();
@@ -56,7 +58,7 @@ where
         let mut error = _mm256_setzero_si256();
 
         // Process the beginning of the sequence in SIMD while possible.
-        while i + std::mem::size_of::<__m256i>() < l {
+        while i + STRIDE <= l {
             // Load current row and reset buffers for the encoded result.
             let letters = _mm256_loadu_si256(src_ptr as *const __m256i);
             let mut encoded = _mm256_set1_epi8(A::K::USIZE as i8);
@@ -74,9 +76,9 @@ where
             // Store the encoded result to the output buffer.
             _mm256_storeu_si256(dst_ptr as *mut __m256i, encoded);
             // Advance to the next addresses in input and output.
-            src_ptr = src_ptr.add(std::mem::size_of::<__m256i>());
-            dst_ptr = dst_ptr.add(std::mem::size_of::<__m256i>());
-            i += std::mem::size_of::<__m256i>();
+            src_ptr = src_ptr.add(STRIDE);
+            dst_ptr = dst_ptr.add(STRIDE);
+            i += STRIDE;
         }
 
         // If an invalid symbol was encountered, recover which one.
@@ -88,7 +90,9 @@ where
         }
 
         // Encode the rest of the sequence using the generic implementation.
-        g.encode_into(&seq[i..], &mut dst[i..])?;
+        if i < l {
+            g.encode_into(&seq[i..], &mut dst[i..])?;
+        }
     }
 
     Ok(())
diff --git a/lightmotif/src/pli/platform/sse2.rs b/lightmotif/src/pli/platform/sse2.rs
index 4609440..212da5b 100644
--- a/lightmotif/src/pli/platform/sse2.rs
+++ b/lightmotif/src/pli/platform/sse2.rs
@@ -11,12 +11,16 @@ use std::ops::Rem;
 
 use super::Backend;
 use crate::abc::Alphabet;
+use crate::abc::Symbol;
 use crate::dense::DenseMatrix;
 use crate::dense::MatrixCoordinates;
+use crate::err::InvalidSymbol;
 use crate::num::consts::U16;
 use crate::num::MultipleOf;
 use crate::num::StrictlyPositive;
 use crate::num::Unsigned;
+use crate::pli::Encode;
+use crate::pli::Pipeline;
 use crate::pwm::ScoringMatrix;
 use crate::scores::StripedScores;
 use crate::seq::StripedSequence;
@@ -29,6 +33,73 @@ impl Backend for Sse2 {
     type LANES = U16;
 }
 
+/// Encodes the ASCII bytes of `seq` into `dst` using SSE2, one 16-byte vector at a time.
+/// Panics on a length mismatch; unrecognized bytes surface as `Err`. Safety: CPU must have SSE2.
+#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
+#[target_feature(enable = "sse2")]
+#[allow(overflowing_literals)]
+unsafe fn encode_into_sse2<A>(seq: &[u8], dst: &mut [A::Symbol]) -> Result<(), InvalidSymbol>
+where
+    A: Alphabet,
+{
+    const STRIDE: usize = std::mem::size_of::<__m128i>();
+    let alphabet = A::as_str().as_bytes();
+    let g = Pipeline::<A, _>::generic();
+    let l = seq.len();
+    assert_eq!(seq.len(), dst.len());
+
+    unsafe {
+        // Use raw pointers since we cannot be sure `seq` and `dst` are aligned.
+        let mut i = 0;
+        let mut src_ptr = seq.as_ptr();
+        let mut dst_ptr = dst.as_mut_ptr();
+
+        // Store a flag to know if invalid letters have been encountered.
+        let mut error = _mm_setzero_si128();
+
+        // Process every full 16-byte vector in SIMD, including one ending exactly at `l`.
+        while i + STRIDE <= l {
+            // Load current row and reset buffers for the encoded result.
+            let letters = _mm_loadu_si128(src_ptr as *const __m128i);
+            let mut encoded = _mm_set1_epi8((A::K::USIZE - 1) as i8);
+            let mut unknown = _mm_set1_epi8(0xFF);
+            // Check symbols one by one and match them to the letters.
+            for a in 0..A::K::USIZE {
+                let index = _mm_set1_epi8(a as i8);
+                let ascii = _mm_set1_epi8(alphabet[a] as i8);
+                let m = _mm_cmpeq_epi8(letters, ascii);
+                encoded = _mm_or_si128(_mm_andnot_si128(m, encoded), _mm_and_si128(m, index));
+                unknown = _mm_andnot_si128(m, unknown);
+            }
+            // Record if some symbols of the current vector are unknown.
+            error = _mm_or_si128(error, unknown);
+            // Store the encoded result (assumes `A::Symbol` is one byte wide — TODO confirm).
+            _mm_storeu_si128(dst_ptr as *mut __m128i, encoded);
+            // Advance to the next addresses in input and output.
+            src_ptr = src_ptr.add(STRIDE);
+            dst_ptr = dst_ptr.add(STRIDE);
+            i += STRIDE;
+        }
+
+        // If an invalid symbol was encountered, recover which one.
+        // FIXME: vectorize the error search?
+        let mut x: [u8; 16] = [0; 16];
+        _mm_storeu_si128(x.as_mut_ptr() as *mut __m128i, error);
+        if x.iter().any(|&x| x != 0) {
+            for &c in seq {
+                let _ = A::Symbol::from_ascii(c)?;
+            }
+        }
+
+        // Encode the remaining tail, if any, using the generic implementation.
+        if i < l {
+            g.encode_into(&seq[i..], &mut dst[i..])?;
+        }
+    }
+
+    Ok(())
+}
+
 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
 #[target_feature(enable = "sse2")]
 unsafe fn score_sse2<A: Alphabet, C: MultipleOf<<Sse2 as Backend>::LANES>>(
@@ -184,6 +255,19 @@ unsafe fn argmax_sse2<C: MultipleOf<<Sse2 as Backend>::LANES>>(
 }
 
 impl Sse2 {
+    #[allow(unused)]
+    pub fn encode_into<A>(seq: &[u8], dst: &mut [A::Symbol]) -> Result<(), InvalidSymbol>
+    where
+        A: Alphabet,
+    {
+        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
+        unsafe { // NOTE(review): sound only if the host CPU has SSE2 — confirm callers ensure it
+            encode_into_sse2::<A>(seq, dst)
+        }
+        #[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
+        panic!("attempting to run SSE2 code on a non-x86 host"); // encode_into_sse2 is x86-only
+    }
+
     #[allow(unused)]
     pub fn score_rows_into<A, C, S, M>(
         pssm: M,
-- 
GitLab