From eba9e7d160a78dbc3d4cb999cc3141b76f20c8a8 Mon Sep 17 00:00:00 2001
From: Martin Larralde <martin.larralde@embl.de>
Date: Fri, 21 Jun 2024 00:55:31 +0200
Subject: [PATCH] Implement `Score<u8, Dna>` for the NEON platform

---
 lightmotif/src/pli/dispatch.rs      |  5 +-
 lightmotif/src/pli/mod.rs           | 32 ++++++++++-
 lightmotif/src/pli/platform/avx2.rs |  4 +-
 lightmotif/src/pli/platform/neon.rs | 86 ++++++++++++++++++++++++++++-
 4 files changed, 119 insertions(+), 8 deletions(-)

diff --git a/lightmotif/src/pli/dispatch.rs b/lightmotif/src/pli/dispatch.rs
index 7e93906..f6a4010 100644
--- a/lightmotif/src/pli/dispatch.rs
+++ b/lightmotif/src/pli/dispatch.rs
@@ -148,8 +148,8 @@ impl Score<u8, Dna, <Dispatch as Backend>::LANES> for Pipeline<Dna, Dispatch> {
             Dispatch::Avx2 => Avx2::score_u8_rows_into_shuffle(pssm, seq.as_ref(), rows, scores),
             // #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
             // Dispatch::Sse2 => Sse2::score_rows_into(pssm, seq.as_ref(), rows, scores),
-            // #[cfg(any(target_arch = "arm", target_arch = "aarch64"))]
-            // Dispatch::Neon => Neon::score_rows_into(pssm, seq.as_ref(), rows, scores),
+            #[cfg(any(target_arch = "arm", target_arch = "aarch64"))]
+            Dispatch::Neon => Neon::score_u8_rows_into(pssm, seq.as_ref(), rows, scores),
             _ => <Generic as Score<u8, Dna, <Dispatch as Backend>::LANES>>::score_rows_into(
                 &Generic,
                 pssm,
@@ -198,6 +198,7 @@ impl<A: Alphabet> Maximum<u8, <Dispatch as Backend>::LANES> for Pipeline<A, Disp
         scores: &StripedScores<u8, <Dispatch as Backend>::LANES>,
     ) -> Option<MatrixCoordinates> {
         match self.backend {
+            // FIXME !!!!
             #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
             Dispatch::Avx2 => Avx2::argmax_u8(scores),
             // #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
diff --git a/lightmotif/src/pli/mod.rs b/lightmotif/src/pli/mod.rs
index 25a4513..056f933 100644
--- a/lightmotif/src/pli/mod.rs
+++ b/lightmotif/src/pli/mod.rs
@@ -4,6 +4,8 @@ use std::ops::Add;
 use std::ops::AddAssign;
 use std::ops::Range;
 
+use typenum::IsLessOrEqual;
+
 use crate::abc::Alphabet;
 use crate::abc::Dna;
 use crate::abc::Protein;
@@ -16,7 +18,6 @@ use crate::err::UnsupportedBackend;
 use crate::num::MultipleOf;
 use crate::num::StrictlyPositive;
 use crate::num::U16;
-use crate::pwm::ScoringMatrix;
 use crate::scores::StripedScores;
 use crate::seq::EncodedSequence;
 use crate::seq::StripedSequence;
@@ -523,13 +524,38 @@ where
     A: Alphabet,
     C: StrictlyPositive + MultipleOf<U16>,
 {
+    #[inline]
+    fn score_rows_into<S, M>(
+        &self,
+        pssm: M,
+        seq: S,
+        rows: Range<usize>,
+        scores: &mut StripedScores<f32, C>,
+    ) where
+        S: AsRef<StripedSequence<A, C>>,
+        M: AsRef<DenseMatrix<f32, <A as Alphabet>::K>>,
+    {
+        Neon::score_f32_rows_into(pssm, seq, rows, scores);
+    }
 }
 
-impl<A, C> Score<u8, A, C> for Pipeline<A, Neon>
+impl<C> Score<u8, Dna, C> for Pipeline<Dna, Neon>
 where
-    A: Alphabet,
     C: StrictlyPositive + MultipleOf<U16>,
 {
+    #[inline]
+    fn score_rows_into<S, M>(
+        &self,
+        pssm: M,
+        seq: S,
+        rows: Range<usize>,
+        scores: &mut StripedScores<u8, C>,
+    ) where
+        S: AsRef<StripedSequence<Dna, C>>,
+        M: AsRef<DenseMatrix<u8, <Dna as Alphabet>::K>>,
+    {
+        Neon::score_u8_rows_into(pssm, seq, rows, scores);
+    }
 }
 
 impl<A, C> Maximum<f32, C> for Pipeline<A, Neon>
diff --git a/lightmotif/src/pli/platform/avx2.rs b/lightmotif/src/pli/platform/avx2.rs
index f8961da..8f90e9d 100644
--- a/lightmotif/src/pli/platform/avx2.rs
+++ b/lightmotif/src/pli/platform/avx2.rs
@@ -301,9 +301,9 @@ pub unsafe fn score_u8_avx2_shuffle<A>(
                 // because in AVX2 shuffle operates on the two halves independently.
                 let t = _mm256_broadcastsi128_si256(_mm_load_si128(&*(pssmptr as *const __m128i)));
                 // load scores for given sequence
-                let x = _mm256_shuffle_epi8(t, x);
+                let y = _mm256_shuffle_epi8(t, x);
                 // add scores to the running sum
-                s = _mm256_adds_epu8(s, x);
+                s = _mm256_adds_epu8(s, y);
                 // advance to next row in PSSM and sequence matrices
                 seqptr = seqptr.add(seq.matrix().stride());
                 pssmptr = pssmptr.add(pssm.stride());
diff --git a/lightmotif/src/pli/platform/neon.rs b/lightmotif/src/pli/platform/neon.rs
index 2c63635..66f38bf 100644
--- a/lightmotif/src/pli/platform/neon.rs
+++ b/lightmotif/src/pli/platform/neon.rs
@@ -14,11 +14,13 @@ use crate::abc::Alphabet;
 use crate::abc::Symbol;
 use crate::dense::DenseMatrix;
 use crate::err::InvalidSymbol;
-use crate::num::consts::U16;
+use crate::num::IsLessOrEqual;
 use crate::num::MultipleOf;
+use crate::num::NonZero;
 use crate::num::StrictlyPositive;
 use crate::num::Unsigned;
 use crate::num::Zero;
+use crate::num::U16;
 use crate::pli::Encode;
 use crate::pli::Pipeline;
 use crate::pwm::ScoringMatrix;
@@ -180,6 +182,50 @@ unsafe fn score_f32_neon<A: Alphabet, C: MultipleOf<U16>>(
     }
 }
 
+/// Score `rows` of a striped sequence against a `u8` PSSM using NEON.
+///
+/// Row scores are accumulated with a saturating add (`vqaddq_u8`), like the
+/// AVX2 kernel does with `_mm256_adds_epu8`, so sums clamp at 255.
+#[cfg(any(target_arch = "arm", target_arch = "aarch64"))]
+#[target_feature(enable = "neon")]
+unsafe fn score_u8_neon<A: Alphabet, C: MultipleOf<U16>>(
+    pssm: &DenseMatrix<u8, A::K>,
+    seq: &StripedSequence<A, C>,
+    rows: Range<usize>,
+    scores: &mut StripedScores<u8, C>,
+) {
+    // process columns of the striped matrix, any multiple of 16 is supported
+    let data = scores.matrix_mut();
+    for offset in (0..C::Quotient::USIZE).map(|i| i * <Neon as Backend>::LANES::USIZE) {
+        let mut rowptr = data[0].as_mut_ptr().add(offset);
+        // process every position of the sequence data
+        for i in rows.clone() {
+            // reset sums for current position
+            let mut s = vdupq_n_u8(0);
+            // reset position
+            let mut seqptr = seq.matrix()[i].as_ptr().add(offset);
+            let mut pssmptr = pssm[0].as_ptr();
+            // advance position in the position weight matrix
+            for _ in 0..pssm.rows() {
+                // load sequence row
+                let x = vld1q_u8(seqptr as *const u8);
+                // load pssm row
+                let t = vld1q_u8(pssmptr as *const u8);
+                // shuffle pssm with the sequence characters
+                let y = vqtbl1q_u8(t, x);
+                // add scores to the running sum, saturating instead of wrapping
+                s = vqaddq_u8(s, y);
+                // advance to next row in PSSM and sequence matrices
+                seqptr = seqptr.add(seq.matrix().stride());
+                pssmptr = pssmptr.add(pssm.stride());
+            }
+            // record the score for the current position
+            vst1q_u8(rowptr, s);
+            rowptr = rowptr.add(data.stride());
+        }
+    }
+}
+
 impl Neon {
     #[allow(unused)]
     pub fn encode_into<A>(seq: &[u8], dst: &mut [A::Symbol]) -> Result<(), InvalidSymbol>
@@ -232,4 +278,42 @@ impl Neon {
         #[cfg(not(any(target_arch = "arm", target_arch = "aarch64")))]
         panic!("attempting to run NEON code on a non-Arm host")
     }
+
+    #[allow(unused)]
+    pub fn score_u8_rows_into<A, C, S, M>(
+        pssm: M,
+        seq: S,
+        rows: Range<usize>,
+        scores: &mut StripedScores<u8, C>,
+    ) where
+        <A as Alphabet>::K: IsLessOrEqual<U16>,
+        <<A as Alphabet>::K as IsLessOrEqual<U16>>::Output: NonZero,
+        A: Alphabet,
+        C: MultipleOf<U16>,
+        S: AsRef<StripedSequence<A, C>>,
+        M: AsRef<DenseMatrix<u8, A::K>>,
+    {
+        let seq = seq.as_ref();
+        let pssm = pssm.as_ref();
+
+        if seq.wrap() < pssm.rows() - 1 {
+            panic!(
+                "not enough wrapping rows for motif of length {}",
+                pssm.rows()
+            );
+        }
+
+        if seq.len() < pssm.rows() || rows.len() == 0 {
+            scores.resize(0, 0);
+            return;
+        }
+
+        scores.resize(rows.len(), (seq.len() + 1).saturating_sub(pssm.rows()));
+        #[cfg(any(target_arch = "arm", target_arch = "aarch64"))]
+        unsafe {
+            score_u8_neon(pssm, seq, rows, scores);
+        }
+        #[cfg(not(any(target_arch = "arm", target_arch = "aarch64")))]
+        panic!("attempting to run NEON code on a non-Arm host")
+    }
 }
-- 
GitLab