From 3d0912231b3c96d0c876a16a0784ba7c8bb40c1c Mon Sep 17 00:00:00 2001
From: Martin Larralde <martin.larralde@embl.de>
Date: Thu, 20 Jun 2024 22:58:13 +0200
Subject: [PATCH] Fix build of platform specific code on Arm platforms

---
 lightmotif/src/pli/dispatch.rs      |  4 ++--
 lightmotif/src/pli/platform/avx2.rs |  1 +
 lightmotif/src/pli/platform/neon.rs | 32 +++++++++++++++--------------
 3 files changed, 20 insertions(+), 17 deletions(-)

diff --git a/lightmotif/src/pli/dispatch.rs b/lightmotif/src/pli/dispatch.rs
index 6351413..7e93906 100644
--- a/lightmotif/src/pli/dispatch.rs
+++ b/lightmotif/src/pli/dispatch.rs
@@ -91,7 +91,7 @@ impl Score<f32, Dna, <Dispatch as Backend>::LANES> for Pipeline<Dna, Dispatch> {
             #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
             Dispatch::Sse2 => Sse2::score_rows_into(pssm, seq.as_ref(), rows, scores),
             #[cfg(any(target_arch = "arm", target_arch = "aarch64"))]
-            Dispatch::Neon => Neon::score_rows_into(pssm, seq.as_ref(), rows, scores),
+            Dispatch::Neon => Neon::score_f32_rows_into(pssm, seq.as_ref(), rows, scores),
             _ => <Generic as Score<f32, Dna, <Dispatch as Backend>::LANES>>::score_rows_into(
                 &Generic,
                 pssm,
@@ -120,7 +120,7 @@ impl Score<f32, Protein, <Dispatch as Backend>::LANES> for Pipeline<Protein, Dis
             #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
             Dispatch::Sse2 => Sse2::score_rows_into(pssm, seq.as_ref(), rows, scores),
             #[cfg(any(target_arch = "arm", target_arch = "aarch64"))]
-            Dispatch::Neon => Neon::score_rows_into(pssm, seq.as_ref(), rows, scores),
+            Dispatch::Neon => Neon::score_f32_rows_into(pssm, seq.as_ref(), rows, scores),
             _ => <Generic as Score<f32, Protein, <Dispatch as Backend>::LANES>>::score_rows_into(
                 &Generic,
                 pssm,
diff --git a/lightmotif/src/pli/platform/avx2.rs b/lightmotif/src/pli/platform/avx2.rs
index c2bce4b..f8961da 100644
--- a/lightmotif/src/pli/platform/avx2.rs
+++ b/lightmotif/src/pli/platform/avx2.rs
@@ -273,6 +273,7 @@ unsafe fn score_f32_avx2_gather<A>(
     _mm_sfence();
 }
 
+#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
 #[target_feature(enable = "avx2")]
 pub unsafe fn score_u8_avx2_shuffle<A>(
     pssm: &DenseMatrix<u8, A::K>,
diff --git a/lightmotif/src/pli/platform/neon.rs b/lightmotif/src/pli/platform/neon.rs
index 21cf142..b03874d 100644
--- a/lightmotif/src/pli/platform/neon.rs
+++ b/lightmotif/src/pli/platform/neon.rs
@@ -12,6 +12,7 @@ use std::ops::Rem;
 use super::Backend;
 use crate::abc::Alphabet;
 use crate::abc::Symbol;
+use crate::dense::DenseMatrix;
 use crate::err::InvalidSymbol;
 use crate::num::consts::U16;
 use crate::num::MultipleOf;
@@ -20,9 +21,8 @@ use crate::num::Unsigned;
 use crate::num::Zero;
 use crate::pli::Encode;
 use crate::pli::Pipeline;
-use crate::scores::StripedScores;
-
 use crate::pwm::ScoringMatrix;
+use crate::scores::StripedScores;
 use crate::seq::StripedSequence;
 
 /// A marker type for the SSE2 implementation of the pipeline.
@@ -123,12 +123,14 @@ where
 
 #[cfg(any(target_arch = "arm", target_arch = "aarch64"))]
 #[target_feature(enable = "neon")]
-unsafe fn score_neon<A: Alphabet, C: MultipleOf<U16>>(
-    pssm: &ScoringMatrix<A>,
+unsafe fn score_f32_neon<A: Alphabet, C: MultipleOf<U16>>(
+    pssm: &DenseMatrix<f32, A::K>,
     seq: &StripedSequence<A, C>,
     rows: Range<usize>,
-    scores: &mut StripedScores<C>,
+    scores: &mut StripedScores<f32, C>,
 ) {
+    use crate::dense::DenseMatrix;
+
     let zero_u8 = vdupq_n_u8(0);
     let zero_f32 = vdupq_n_f32(0.0);
     // process columns of the striped matrix, any multiple of 16 is supported
@@ -140,9 +142,9 @@ unsafe fn score_neon<A: Alphabet, C: MultipleOf<U16>>(
             let mut s = float32x4x4_t(zero_f32, zero_f32, zero_f32, zero_f32);
             // reset position
             let mut dataptr = seq.matrix()[i].as_ptr().add(offset);
-            let mut pssmptr = pssm.matrix()[0].as_ptr();
+            let mut pssmptr = pssm[0].as_ptr();
             // advance position in the position weight matrix
-            for _ in 0..pssm.len() {
+            for _ in 0..pssm.rows() {
                 // load sequence row and broadcast to f32
                 let x = vld1q_u8(dataptr as *const u8);
                 let z = vzipq_u8(x, zero_u8);
@@ -167,7 +169,7 @@ unsafe fn score_neon<A: Alphabet, C: MultipleOf<U16>>(
                 }
                 // advance to next row in sequence and PSSM matrices
                 dataptr = dataptr.add(seq.matrix().stride());
-                pssmptr = pssmptr.add(pssm.matrix().stride());
+                pssmptr = pssmptr.add(pssm.stride());
             }
             // record the score for the current position
             let row = &mut data[i];
@@ -194,7 +196,7 @@ impl Neon {
     }
 
     #[allow(unused)]
-    pub fn score_rows_into<A, C, S, M>(
+    pub fn score_f32_rows_into<A, C, S, M>(
         pssm: M,
         seq: S,
         rows: Range<usize>,
@@ -203,27 +205,27 @@ impl Neon {
         A: Alphabet,
         C: MultipleOf<U16>,
         S: AsRef<StripedSequence<A, C>>,
-        M: AsRef<ScoringMatrix<A>>,
+        M: AsRef<DenseMatrix<f32, A::K>>,
     {
         let seq = seq.as_ref();
         let pssm = pssm.as_ref();
 
-        if seq.wrap() < pssm.len() - 1 {
+        if seq.wrap() < pssm.rows() - 1 {
             panic!(
                 "not enough wrapping rows for motif of length {}",
-                pssm.len()
+                pssm.rows()
             );
         }
 
-        if seq.len() < pssm.len() || rows.len() == 0 {
+        if seq.len() < pssm.rows() || rows.len() == 0 {
             scores.resize(0, 0);
             return;
         }
 
-        scores.resize(rows.len(), (seq.len() + 1).saturating_sub(pssm.len()));
+        scores.resize(rows.len(), (seq.len() + 1).saturating_sub(pssm.rows()));
         #[cfg(any(target_arch = "arm", target_arch = "aarch64"))]
         unsafe {
-            score_neon(pssm, seq, rows, scores);
+            score_f32_neon(pssm, seq, rows, scores);
         }
         #[cfg(not(any(target_arch = "arm", target_arch = "aarch64")))]
         panic!("attempting to run NEON code on a non-Arm host")
-- 
GitLab