From b167120d0c59c19cd15ecae1d2a9e4538ed16ead Mon Sep 17 00:00:00 2001 From: Martin Larralde <martin.larralde@embl.de> Date: Fri, 21 Jun 2024 11:20:02 +0200 Subject: [PATCH] Add SSE2 implementation of the `Encode` pipeline operation --- lightmotif/src/pli/mod.rs | 11 +++- lightmotif/src/pli/platform/avx2.rs | 14 +++-- lightmotif/src/pli/platform/sse2.rs | 84 +++++++++++++++++++++++++++++ 3 files changed, 103 insertions(+), 6 deletions(-) diff --git a/lightmotif/src/pli/mod.rs b/lightmotif/src/pli/mod.rs index 056f933..b6e8f4e 100644 --- a/lightmotif/src/pli/mod.rs +++ b/lightmotif/src/pli/mod.rs @@ -323,7 +323,16 @@ impl<A: Alphabet> Pipeline<A, Sse2> { } } -impl<A: Alphabet> Encode<A> for Pipeline<A, Sse2> {} +impl<A: Alphabet> Encode<A> for Pipeline<A, Sse2> { + #[inline] + fn encode_into<S: AsRef<[u8]>>( + &self, + seq: S, + dst: &mut [A::Symbol], + ) -> Result<(), InvalidSymbol> { + Sse2::encode_into::<A>(seq.as_ref(), dst) + } +} impl<A, C> Score<f32, A, C> for Pipeline<A, Sse2> where diff --git a/lightmotif/src/pli/platform/avx2.rs b/lightmotif/src/pli/platform/avx2.rs index 1e731f0..0f5efe2 100644 --- a/lightmotif/src/pli/platform/avx2.rs +++ b/lightmotif/src/pli/platform/avx2.rs @@ -41,6 +41,8 @@ unsafe fn encode_into_avx2<A>(seq: &[u8], dst: &mut [A::Symbol]) -> Result<(), I where A: Alphabet, { + const STRIDE: usize = std::mem::size_of::<__m256i>(); + let alphabet = A::as_str().as_bytes(); let g = Pipeline::<A, _>::generic(); let l = seq.len(); @@ -56,7 +58,7 @@ where let mut error = _mm256_setzero_si256(); // Process the beginning of the sequence in SIMD while possible. - while i + std::mem::size_of::<__m256i>() < l { + while i + STRIDE <= l { // Load current row and reset buffers for the encoded result. let letters = _mm256_loadu_si256(src_ptr as *const __m256i); let mut encoded = _mm256_set1_epi8(A::K::USIZE as i8); @@ -74,9 +76,9 @@ where // Store the encoded result to the output buffer. _mm256_storeu_si256(dst_ptr as *mut __m256i, encoded); // Advance to the next addresses in input and output. - src_ptr = src_ptr.add(std::mem::size_of::<__m256i>()); - dst_ptr = dst_ptr.add(std::mem::size_of::<__m256i>()); - i += std::mem::size_of::<__m256i>(); + src_ptr = src_ptr.add(STRIDE); + dst_ptr = dst_ptr.add(STRIDE); + i += STRIDE; } // If an invalid symbol was encountered, recover which one. @@ -88,7 +90,9 @@ where } // Encode the rest of the sequence using the generic implementation. - g.encode_into(&seq[i..], &mut dst[i..])?; + if i < l { + g.encode_into(&seq[i..], &mut dst[i..])?; + } } Ok(()) diff --git a/lightmotif/src/pli/platform/sse2.rs b/lightmotif/src/pli/platform/sse2.rs index 4609440..212da5b 100644 --- a/lightmotif/src/pli/platform/sse2.rs +++ b/lightmotif/src/pli/platform/sse2.rs @@ -11,12 +11,16 @@ use std::ops::Rem; use super::Backend; use crate::abc::Alphabet; +use crate::abc::Symbol; use crate::dense::DenseMatrix; use crate::dense::MatrixCoordinates; +use crate::err::InvalidSymbol; use crate::num::consts::U16; use crate::num::MultipleOf; use crate::num::StrictlyPositive; use crate::num::Unsigned; +use crate::pli::Encode; +use crate::pli::Pipeline; use crate::pwm::ScoringMatrix; use crate::scores::StripedScores; use crate::seq::StripedSequence; @@ -29,6 +33,73 @@ impl Backend for Sse2 { type LANES = U16; } +#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] +#[target_feature(enable = "sse2")] +#[allow(overflowing_literals)] +unsafe fn encode_into_sse2<A>(seq: &[u8], dst: &mut [A::Symbol]) -> Result<(), InvalidSymbol> +where + A: Alphabet, +{ + const STRIDE: usize = std::mem::size_of::<__m128i>(); + + let alphabet = A::as_str().as_bytes(); + let g = Pipeline::<A, _>::generic(); + let l = seq.len(); + assert_eq!(seq.len(), dst.len()); + + unsafe { + // Use raw pointers since we cannot be sure `seq` and `dst` are aligned. + let mut i = 0; + let mut src_ptr = seq.as_ptr(); + let mut dst_ptr = dst.as_mut_ptr(); + + // Store a flag to know if invalid letters have been encountered. + let mut error = _mm_setzero_si128(); + + // Process the beginning of the sequence in SIMD while possible. + while i + STRIDE < l { + // Load current row and reset buffers for the encoded result. + let letters = _mm_loadu_si128(src_ptr as *const __m128i); + let mut encoded = _mm_set1_epi8((A::K::USIZE - 1) as i8); + let mut unknown = _mm_set1_epi8(0xFF); + // Check symbols one by one and match them to the letters. + for a in 0..A::K::USIZE { + let index = _mm_set1_epi8(a as i8); + let ascii = _mm_set1_epi8(alphabet[a] as i8); + let m = _mm_cmpeq_epi8(letters, ascii); + encoded = _mm_or_si128(_mm_andnot_si128(m, encoded), _mm_and_si128(m, index)); + unknown = _mm_andnot_si128(m, unknown); + } + // Record is some symbols of the current vector are unknown. + error = _mm_or_si128(error, unknown); + // Store the encoded result to the output buffer. + _mm_storeu_si128(dst_ptr as *mut __m128i, encoded); + // Advance to the next addresses in input and output. + src_ptr = src_ptr.add(STRIDE); + dst_ptr = dst_ptr.add(STRIDE); + i += STRIDE; + } + + // If an invalid symbol was encountered, recover which one. + // FIXME: run a vectorize the error search? + let mut x: [u8; 16] = [0; 16]; + _mm_storeu_si128(x.as_mut_ptr() as *mut __m128i, error); + if x.iter().any(|&x| x != 0) { + for i in 0..l { + let _ = A::Symbol::from_ascii(seq[i])?; + } + } + + // Encode the rest of the sequence using the generic implementation. + // g.encode_into(&seq[i..], &mut dst[i..])?; + if i < l { + g.encode_into(&seq[i..], &mut dst[i..])?; + } + } + + Ok(()) +} + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] #[target_feature(enable = "sse2")] unsafe fn score_sse2<A: Alphabet, C: MultipleOf<<Sse2 as Backend>::LANES>>( @@ -184,6 +255,19 @@ unsafe fn argmax_sse2<C: MultipleOf<<Sse2 as Backend>::LANES>>( } impl Sse2 { + #[allow(unused)] + pub fn encode_into<A>(seq: &[u8], dst: &mut [A::Symbol]) -> Result<(), InvalidSymbol> + where + A: Alphabet, + { + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] + unsafe { + encode_into_sse2::<A>(seq, dst) + } + #[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))] + panic!("attempting to run SSE2 code on a non-x86 host"); + } + #[allow(unused)] pub fn score_rows_into<A, C, S, M>( pssm: M, -- GitLab