From 00556952b29986bf324985f6729a0152c89bbd88 Mon Sep 17 00:00:00 2001
From: Martin Larralde <martin.larralde@embl.de>
Date: Thu, 20 Jun 2024 14:12:35 +0200
Subject: [PATCH] Add AVX2 implementation of `Score<u8>` operation

---
 lightmotif/src/pli/dispatch.rs      | 48 ++++++++++++++-
 lightmotif/src/pli/mod.rs           | 55 ++++++++++++++++-
 lightmotif/src/pli/platform/avx2.rs | 91 +++++++++++++++++++++++++++--
 3 files changed, 183 insertions(+), 11 deletions(-)

diff --git a/lightmotif/src/pli/dispatch.rs b/lightmotif/src/pli/dispatch.rs
index 4c6af58..2879088 100644
--- a/lightmotif/src/pli/dispatch.rs
+++ b/lightmotif/src/pli/dispatch.rs
@@ -87,7 +87,7 @@ impl Score<f32, Dna, <Dispatch as Backend>::LANES> for Pipeline<Dna, Dispatch> {
     {
         match self.backend {
             #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
-            Dispatch::Avx2 => Avx2::score_rows_into_permute(pssm, seq.as_ref(), rows, scores),
+            Dispatch::Avx2 => Avx2::score_f32_rows_into_permute(pssm, seq.as_ref(), rows, scores),
             #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
             Dispatch::Sse2 => Sse2::score_rows_into(pssm, seq.as_ref(), rows, scores),
             #[cfg(any(target_arch = "arm", target_arch = "aarch64"))]
@@ -116,7 +116,7 @@ impl Score<f32, Protein, <Dispatch as Backend>::LANES> for Pipeline<Protein, Dis
     {
         match self.backend {
             #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
-            Dispatch::Avx2 => Avx2::score_rows_into_gather(pssm, seq.as_ref(), rows, scores),
+            Dispatch::Avx2 => Avx2::score_f32_rows_into_gather(pssm, seq.as_ref(), rows, scores),
             #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
             Dispatch::Sse2 => Sse2::score_rows_into(pssm, seq.as_ref(), rows, scores),
             #[cfg(any(target_arch = "arm", target_arch = "aarch64"))]
@@ -132,6 +132,35 @@ impl Score<f32, Protein, <Dispatch as Backend>::LANES> for Pipeline<Protein, Dis
     }
 }
 
+impl Score<u8, Dna, <Dispatch as Backend>::LANES> for Pipeline<Dna, Dispatch> {
+    fn score_rows_into<S, M>(
+        &self,
+        pssm: M,
+        seq: S,
+        rows: Range<usize>,
+        scores: &mut StripedScores<u8, <Dispatch as Backend>::LANES>,
+    ) where
+        S: AsRef<StripedSequence<Dna, <Dispatch as Backend>::LANES>>,
+        M: AsRef<DenseMatrix<u8, <Dna as Alphabet>::K>>,
+    {
+        match self.backend {
+            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
+            Dispatch::Avx2 => Avx2::score_u8_rows_into_shuffle(pssm, seq.as_ref(), rows, scores),
+            // NOTE: only the AVX2 arm above has a dedicated `u8` scoring call. Every
+            //       other backend, and any target where that arm is compiled out,
+            //       is routed to the `Generic` implementation by the wildcard arm.
+            // TODO: dispatch to `Sse2` / `Neon` once they support `u8` scoring.
+            _ => <Generic as Score<u8, Dna, <Dispatch as Backend>::LANES>>::score_rows_into(
+                &Generic,
+                pssm,
+                seq.as_ref(),
+                rows,
+                scores,
+            ),
+        }
+    }
+}
+
 impl<A: Alphabet> Stripe<A, <Dispatch as Backend>::LANES> for Pipeline<A, Dispatch> {
     fn stripe_into<S: AsRef<[A::Symbol]>>(
         &self,
@@ -163,4 +192,19 @@ impl<A: Alphabet> Maximum<f32, <Dispatch as Backend>::LANES> for Pipeline<A, Dis
     }
 }
 
+impl<A: Alphabet> Maximum<u8, <Dispatch as Backend>::LANES> for Pipeline<A, Dispatch> {
+    fn argmax(
+        &self,
+        scores: &StripedScores<u8, <Dispatch as Backend>::LANES>,
+    ) -> Option<MatrixCoordinates> {
+        match self.backend {
+            // NOTE: the match only has a wildcard arm for now, so every backend
+            //       goes through the `Generic` implementation of `argmax`.
+            // TODO: dispatch to `Avx2::argmax` / `Sse2::argmax` once they accept
+            //       `u8` scores.
+            _ => <Generic as Maximum<u8, <Dispatch as Backend>::LANES>>::argmax(&Generic, scores),
+        }
+    }
+}
+
 impl<A: Alphabet> Threshold<f32, <Dispatch as Backend>::LANES> for Pipeline<A, Dispatch> {}
diff --git a/lightmotif/src/pli/mod.rs b/lightmotif/src/pli/mod.rs
index 8df645e..9c3a510 100644
--- a/lightmotif/src/pli/mod.rs
+++ b/lightmotif/src/pli/mod.rs
@@ -309,6 +309,7 @@ impl<A: Alphabet> Pipeline<A, Dispatch> {
 
 impl<A: Alphabet> Pipeline<A, Sse2> {
     /// Attempt to create a new SSE2-accelerated pipeline.
+    #[inline]
     pub fn sse2() -> Result<Self, UnsupportedBackend> {
         #[cfg(target_arch = "x86")]
         if std::is_x86_feature_detected!("sse2") {
@@ -328,6 +329,7 @@ where
     A: Alphabet,
     C: StrictlyPositive + MultipleOf<U16>,
 {
+    #[inline]
     fn score_rows_into<S, M>(
         &self,
         pssm: M,
@@ -342,17 +344,39 @@ where
     }
 }
 
+impl<A, C> Score<u8, A, C> for Pipeline<A, Sse2>
+where
+    A: Alphabet,
+    C: StrictlyPositive + MultipleOf<U16>,
+{
+}
+
 impl<A, C> Maximum<f32, C> for Pipeline<A, Sse2>
 where
     A: Alphabet,
     C: StrictlyPositive + MultipleOf<U16>,
 {
+    #[inline]
     fn argmax(&self, scores: &StripedScores<f32, C>) -> Option<MatrixCoordinates> {
         Sse2::argmax(scores)
     }
 }
 
-impl<A: Alphabet, C: StrictlyPositive> Threshold<f32, C> for Pipeline<A, Sse2>
+impl<A, C> Maximum<u8, C> for Pipeline<A, Sse2>
+where
+    A: Alphabet,
+    C: StrictlyPositive + MultipleOf<U16>,
+{
+}
+
+impl<A, C> Threshold<f32, C> for Pipeline<A, Sse2>
+where
+    A: Alphabet,
+    C: StrictlyPositive + MultipleOf<U16>,
+{
+}
+
+impl<A, C> Threshold<u8, C> for Pipeline<A, Sse2>
 where
     A: Alphabet,
     C: StrictlyPositive + MultipleOf<U16>,
@@ -373,6 +397,7 @@ impl<A: Alphabet> Pipeline<A, Avx2> {
 }
 
 impl<A: Alphabet> Encode<A> for Pipeline<A, Avx2> {
+    #[inline]
     fn encode_into<S: AsRef<[u8]>>(
         &self,
         seq: S,
@@ -383,6 +408,7 @@ impl<A: Alphabet> Encode<A> for Pipeline<A, Avx2> {
 }
 
 impl Score<f32, Dna, <Avx2 as Backend>::LANES> for Pipeline<Dna, Avx2> {
+    #[inline]
     fn score_rows_into<S, M>(
         &self,
         pssm: M,
@@ -393,11 +419,12 @@ impl Score<f32, Dna, <Avx2 as Backend>::LANES> for Pipeline<Dna, Avx2> {
         S: AsRef<StripedSequence<Dna, <Avx2 as Backend>::LANES>>,
         M: AsRef<DenseMatrix<f32, <Dna as Alphabet>::K>>,
     {
-        Avx2::score_rows_into_permute(pssm, seq, rows, scores)
+        Avx2::score_f32_rows_into_permute(pssm, seq, rows, scores)
     }
 }
 
 impl Score<f32, Protein, <Avx2 as Backend>::LANES> for Pipeline<Protein, Avx2> {
+    #[inline]
     fn score_rows_into<S, M>(
         &self,
         pssm: M,
@@ -408,12 +435,29 @@ impl Score<f32, Protein, <Avx2 as Backend>::LANES> for Pipeline<Protein, Avx2> {
         S: AsRef<StripedSequence<Protein, <Avx2 as Backend>::LANES>>,
         M: AsRef<DenseMatrix<f32, <Protein as Alphabet>::K>>,
     {
-        Avx2::score_rows_into_gather(pssm, seq, rows, scores)
+        Avx2::score_f32_rows_into_gather(pssm, seq, rows, scores)
+    }
+}
+
+impl Score<u8, Dna, <Avx2 as Backend>::LANES> for Pipeline<Dna, Avx2> {
+    #[inline]
+    fn score_rows_into<S, M>(
+        &self,
+        pssm: M,
+        seq: S,
+        rows: Range<usize>,
+        scores: &mut StripedScores<u8, <Avx2 as Backend>::LANES>,
+    ) where
+        S: AsRef<StripedSequence<Dna, <Avx2 as Backend>::LANES>>,
+        M: AsRef<DenseMatrix<u8, <Dna as Alphabet>::K>>,
+    {
+        Avx2::score_u8_rows_into_shuffle(pssm, seq, rows, scores)
     }
 }
 
 impl<A: Alphabet> Stripe<A, <Avx2 as Backend>::LANES> for Pipeline<A, Avx2> {
     /// Stripe a sequence into the given striped matrix.
+    #[inline]
     fn stripe_into<S: AsRef<[A::Symbol]>>(
         &self,
         seq: S,
@@ -431,13 +475,17 @@ impl<A: Alphabet> Maximum<f32, <Avx2 as Backend>::LANES> for Pipeline<A, Avx2> {
         Avx2::argmax(scores)
     }
 }
+impl<A: Alphabet> Maximum<u8, <Avx2 as Backend>::LANES> for Pipeline<A, Avx2> {}
 
 impl<A: Alphabet> Threshold<f32, <Avx2 as Backend>::LANES> for Pipeline<A, Avx2> {}
 
+impl<A: Alphabet> Threshold<u8, <Avx2 as Backend>::LANES> for Pipeline<A, Avx2> {}
+
 // --- NEON pipeline -----------------------------------------------------------
 
 impl<A: Alphabet> Pipeline<A, Neon> {
     /// Attempt to create a new AVX2-accelerated pipeline.
+    #[inline]
     pub fn neon() -> Result<Self, UnsupportedBackend> {
         #[cfg(target_arch = "arm")]
         if std::arch::is_arm_feature_detected!("neon") {
@@ -452,6 +500,7 @@ impl<A: Alphabet> Pipeline<A, Neon> {
 }
 
 impl<A: Alphabet> Encode<A> for Pipeline<A, Neon> {
+    #[inline]
     fn encode_into<S: AsRef<[u8]>>(
         &self,
         seq: S,
diff --git a/lightmotif/src/pli/platform/avx2.rs b/lightmotif/src/pli/platform/avx2.rs
index 4842dc8..7a1dc3a 100644
--- a/lightmotif/src/pli/platform/avx2.rs
+++ b/lightmotif/src/pli/platform/avx2.rs
@@ -14,6 +14,7 @@ use crate::err::InvalidSymbol;
 use crate::num::IsLessOrEqual;
 use crate::num::NonZero;
 use crate::num::Unsigned;
+use crate::num::U16;
 use crate::num::U32;
 use crate::num::U5;
 use crate::num::U8;
@@ -96,7 +97,7 @@ where
 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
 #[target_feature(enable = "avx2")]
 #[allow(overflowing_literals)]
-unsafe fn score_avx2_permute<A>(
+unsafe fn score_f32_avx2_permute<A>(
     pssm: &DenseMatrix<f32, A::K>,
     seq: &StripedSequence<A, <Avx2 as Backend>::LANES>,
     rows: Range<usize>,
@@ -189,7 +190,7 @@ unsafe fn score_avx2_permute<A>(
 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
 #[target_feature(enable = "avx2")]
 #[allow(overflowing_literals)]
-unsafe fn score_avx2_gather<A>(
+unsafe fn score_f32_avx2_gather<A>(
     pssm: &DenseMatrix<f32, A::K>,
     seq: &StripedSequence<A, <Avx2 as Backend>::LANES>,
     rows: Range<usize>,
@@ -272,6 +273,47 @@ unsafe fn score_avx2_gather<A>(
     _mm_sfence();
 }
 
+/// Score `rows` of `seq` against a `u8` PSSM, using byte shuffles as table lookups.
+///
+/// # Safety
+/// Needs AVX2. `scores` must hold `rows.len()` rows and `seq` must have rows up to
+/// index `rows.end + pssm.rows() - 2`, since rows are accessed through raw pointers.
+#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
+#[target_feature(enable = "avx2")]
+pub unsafe fn score_u8_avx2_shuffle<A>(
+    pssm: &DenseMatrix<u8, A::K>,
+    seq: &StripedSequence<A, <Avx2 as Backend>::LANES>,
+    rows: Range<usize>,
+    scores: &mut StripedScores<u8, <Avx2 as Backend>::LANES>,
+) where
+    A: Alphabet,
+{
+    let data = scores.matrix_mut();
+    let mut rowptr = data[0].as_mut_ptr() as *mut i8;
+    // compute one vector of saturating sums per requested sequence row
+    for i in rows {
+        let mut s = _mm256_setzero_si256();
+        let mut seqptr = seq.matrix()[i].as_ptr();
+        let mut pssmptr = pssm[0].as_ptr();
+        for _ in 0..pssm.rows() {
+            // load one row of the striped sequence; its bytes serve as shuffle indices
+            let x = _mm256_load_si256(seqptr as *const __m256i);
+            // broadcast the PSSM row, as AVX2 shuffles each 128-bit lane independently
+            let t = _mm256_broadcastsi128_si256(_mm_load_si128(pssmptr as *const __m128i));
+            // look up the weight of each symbol, then add with unsigned saturation
+            let x = _mm256_shuffle_epi8(t, x);
+            s = _mm256_adds_epu8(s, x);
+            // advance to next row in PSSM and sequence matrices
+            seqptr = seqptr.add(seq.matrix().stride());
+            pssmptr = pssmptr.add(pssm.stride());
+        }
+        // record the score for the current position
+        _mm256_stream_si256(rowptr as *mut __m256i, s);
+        rowptr = rowptr.add(data.stride());
+    }
+    _mm_sfence();
+}
+
 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
 #[target_feature(enable = "avx2")]
 unsafe fn argmax_avx2(
@@ -614,7 +656,7 @@ impl Avx2 {
     }
 
     #[allow(unused)]
-    pub fn score_rows_into_permute<A, S, M>(
+    pub fn score_f32_rows_into_permute<A, S, M>(
         pssm: M,
         seq: S,
         rows: Range<usize>,
@@ -644,14 +686,14 @@ impl Avx2 {
         scores.resize(rows.len(), (seq.len() + 1).saturating_sub(pssm.rows()));
         #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
         unsafe {
-            score_avx2_permute(pssm, seq, rows, scores)
+            score_f32_avx2_permute(pssm, seq, rows, scores)
         };
         #[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
         panic!("attempting to run AVX2 code on a non-x86 host")
     }
 
     #[allow(unused)]
-    pub fn score_rows_into_gather<A, S, M>(
+    pub fn score_f32_rows_into_gather<A, S, M>(
         pssm: M,
         seq: S,
         rows: Range<usize>,
@@ -679,7 +721,44 @@ impl Avx2 {
         scores.resize(rows.len(), (seq.len() + 1).saturating_sub(pssm.rows()));
         #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
         unsafe {
-            score_avx2_gather(pssm, seq, rows, scores)
+            score_f32_avx2_gather(pssm, seq, rows, scores)
+        };
+        #[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
+        panic!("attempting to run AVX2 code on a non-x86 host")
+    }
+
+    #[allow(unused)]
+    pub fn score_u8_rows_into_shuffle<A, S, M>(
+        pssm: M,
+        seq: S,
+        rows: Range<usize>,
+        scores: &mut StripedScores<u8, <Avx2 as Backend>::LANES>,
+    ) where
+        A: Alphabet,
+        <A as Alphabet>::K: IsLessOrEqual<U16>,
+        <<A as Alphabet>::K as IsLessOrEqual<U16>>::Output: NonZero,
+        S: AsRef<StripedSequence<A, <Avx2 as Backend>::LANES>>,
+        M: AsRef<DenseMatrix<u8, A::K>>,
+    {
+        let seq = seq.as_ref();
+        let pssm = pssm.as_ref();
+
+        if seq.wrap() < pssm.rows() - 1 {
+            panic!(
+                "not enough wrapping rows for motif of length {}",
+                pssm.rows()
+            );
+        }
+
+        if seq.len() < pssm.rows() || rows.len() == 0 {
+            scores.resize(0, 0);
+            return;
+        }
+
+        scores.resize(rows.len(), (seq.len() + 1).saturating_sub(pssm.rows()));
+        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
+        unsafe {
+            score_u8_avx2_shuffle(pssm, seq, rows, scores)
         };
         #[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
         panic!("attempting to run AVX2 code on a non-x86 host")
-- 
GitLab