From 4132b956e2655bdd56726381e9bbcdd8a6648149 Mon Sep 17 00:00:00 2001
From: Martin Larralde <martin.larralde@embl.de>
Date: Mon, 24 Jun 2024 14:08:52 +0200
Subject: [PATCH] Add dedicated `Maximum::<u8, _>::max` implementation for AVX2

---
 lightmotif/src/pli/dispatch.rs      | 10 ++++++++
 lightmotif/src/pli/mod.rs           |  4 +++
 lightmotif/src/pli/platform/avx2.rs | 38 +++++++++++++++++++++++++++++
 3 files changed, 52 insertions(+)

diff --git a/lightmotif/src/pli/dispatch.rs b/lightmotif/src/pli/dispatch.rs
index 7fa8add..1dc505c 100644
--- a/lightmotif/src/pli/dispatch.rs
+++ b/lightmotif/src/pli/dispatch.rs
@@ -205,6 +205,16 @@ impl<A: Alphabet> Maximum<u8, <Dispatch as Backend>::LANES> for Pipeline<A, Disp
             _ => <Generic as Maximum<u8, <Dispatch as Backend>::LANES>>::argmax(&Generic, scores),
         }
     }
+
+    fn max(&self, scores: &StripedScores<u8, <Dispatch as Backend>::LANES>) -> Option<u8> {
+        match self.backend {
+            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
+            Dispatch::Avx2 => Avx2::max_u8(scores),
+            // #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
+            // Dispatch::Sse2 => Sse2::argmax(scores),
+            _ => <Generic as Maximum<u8, <Dispatch as Backend>::LANES>>::max(&Generic, scores),
+        }
+    }
 }
 
 impl<A: Alphabet> Threshold<f32, <Dispatch as Backend>::LANES> for Pipeline<A, Dispatch> {}
diff --git a/lightmotif/src/pli/mod.rs b/lightmotif/src/pli/mod.rs
index 80a053b..83f9bc1 100644
--- a/lightmotif/src/pli/mod.rs
+++ b/lightmotif/src/pli/mod.rs
@@ -492,6 +492,10 @@ impl<A: Alphabet> Maximum<u8, <Avx2 as Backend>::LANES> for Pipeline<A, Avx2> {
     ) -> Option<MatrixCoordinates> {
         Avx2::argmax_u8(scores)
     }
+
+    fn max(&self, scores: &StripedScores<u8, <Avx2 as Backend>::LANES>) -> Option<u8> {
+        Avx2::max_u8(scores)
+    }
 }
 
 impl<A: Alphabet> Threshold<f32, <Avx2 as Backend>::LANES> for Pipeline<A, Avx2> {}
diff --git a/lightmotif/src/pli/platform/avx2.rs b/lightmotif/src/pli/platform/avx2.rs
index 02bc8d0..3df2136 100644
--- a/lightmotif/src/pli/platform/avx2.rs
+++ b/lightmotif/src/pli/platform/avx2.rs
@@ -477,6 +477,34 @@ unsafe fn argmax_u8_avx2(
     }
 }
 
+#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
+#[target_feature(enable = "avx2")]
+unsafe fn max_u8_avx2(scores: &StripedScores<u8, <Avx2 as Backend>::LANES>) -> Option<u8> {
+    if scores.is_empty() {
+        None
+    } else {
+        let data = scores.matrix();
+        unsafe {
+            let mut dataptr = data[0].as_ptr();
+            // store the best scores for each column
+            let mut m = _mm256_setzero_si256();
+            // process all rows iteratively
+            for i in 0..data.rows() {
+                // load scores for the current row
+                let r = _mm256_load_si256(dataptr as *const _);
+                // find highest score
+                m = _mm256_max_epu8(m, r);
+                // advance to next row
+                dataptr = dataptr.add(data.stride());
+            }
+            // find the global maximum across all columns
+            let mut x: [u8; 32] = [0; 32];
+            _mm256_storeu_si256(x.as_mut_ptr() as *mut _, m);
+            x.into_iter().max()
+        }
+    }
+}
+
 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
 #[target_feature(enable = "avx2")]
 unsafe fn stripe_avx2<A>(
@@ -865,6 +893,16 @@ impl Avx2 {
         panic!("attempting to run AVX2 code on a non-x86 host")
     }
 
+    #[allow(unused)]
+    pub fn max_u8(scores: &StripedScores<u8, <Avx2 as Backend>::LANES>) -> Option<u8> {
+        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
+        unsafe {
+            max_u8_avx2(scores)
+        }
+        #[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
+        panic!("attempting to run AVX2 code on a non-x86 host")
+    }
+
     #[allow(unused)]
     pub fn stripe_into<A, S>(seq: S, matrix: &mut StripedSequence<A, <Avx2 as Backend>::LANES>)
     where
-- 
GitLab