Skip to content
Snippets Groups Projects
Commit b167120d authored by Martin Larralde's avatar Martin Larralde
Browse files

Add SSE2 implementation of the `Encode` pipeline operation

parent f355192d
No related branches found
No related tags found
No related merge requests found
......@@ -323,7 +323,16 @@ impl<A: Alphabet> Pipeline<A, Sse2> {
}
}
impl<A: Alphabet> Encode<A> for Pipeline<A, Sse2> {}
// `Encode` for the SSE2 pipeline: the trait method is overridden so that
// encoding goes through the `Sse2` backend.
impl<A: Alphabet> Encode<A> for Pipeline<A, Sse2> {
    /// Encode the bytes of `seq` into `dst` through `Sse2::encode_into`.
    ///
    /// The result of the backend call, success or `InvalidSymbol`, is
    /// returned unchanged.
    #[inline]
    fn encode_into<S: AsRef<[u8]>>(
        &self,
        seq: S,
        dst: &mut [A::Symbol],
    ) -> Result<(), InvalidSymbol> {
        // Borrow the raw bytes once, then let the backend do the work.
        let bytes: &[u8] = seq.as_ref();
        Sse2::encode_into::<A>(bytes, dst)
    }
}
impl<A, C> Score<f32, A, C> for Pipeline<A, Sse2>
where
......
......@@ -41,6 +41,8 @@ unsafe fn encode_into_avx2<A>(seq: &[u8], dst: &mut [A::Symbol]) -> Result<(), I
where
A: Alphabet,
{
const STRIDE: usize = std::mem::size_of::<__m256i>();
let alphabet = A::as_str().as_bytes();
let g = Pipeline::<A, _>::generic();
let l = seq.len();
......@@ -56,7 +58,7 @@ where
let mut error = _mm256_setzero_si256();
// Process the beginning of the sequence in SIMD while possible.
while i + std::mem::size_of::<__m256i>() < l {
while i + STRIDE <= l {
// Load current row and reset buffers for the encoded result.
let letters = _mm256_loadu_si256(src_ptr as *const __m256i);
let mut encoded = _mm256_set1_epi8(A::K::USIZE as i8);
......@@ -74,9 +76,9 @@ where
// Store the encoded result to the output buffer.
_mm256_storeu_si256(dst_ptr as *mut __m256i, encoded);
// Advance to the next addresses in input and output.
src_ptr = src_ptr.add(std::mem::size_of::<__m256i>());
dst_ptr = dst_ptr.add(std::mem::size_of::<__m256i>());
i += std::mem::size_of::<__m256i>();
src_ptr = src_ptr.add(STRIDE);
dst_ptr = dst_ptr.add(STRIDE);
i += STRIDE;
}
// If an invalid symbol was encountered, recover which one.
......@@ -88,7 +90,9 @@ where
}
// Encode the rest of the sequence using the generic implementation.
g.encode_into(&seq[i..], &mut dst[i..])?;
if i < l {
g.encode_into(&seq[i..], &mut dst[i..])?;
}
}
Ok(())
......
......@@ -11,12 +11,16 @@ use std::ops::Rem;
use super::Backend;
use crate::abc::Alphabet;
use crate::abc::Symbol;
use crate::dense::DenseMatrix;
use crate::dense::MatrixCoordinates;
use crate::err::InvalidSymbol;
use crate::num::consts::U16;
use crate::num::MultipleOf;
use crate::num::StrictlyPositive;
use crate::num::Unsigned;
use crate::pli::Encode;
use crate::pli::Pipeline;
use crate::pwm::ScoringMatrix;
use crate::scores::StripedScores;
use crate::seq::StripedSequence;
......@@ -29,6 +33,73 @@ impl Backend for Sse2 {
type LANES = U16;
}
/// Encode an ASCII sequence into symbols of alphabet `A` using SSE2 vectors.
///
/// The input is consumed in chunks of `size_of::<__m128i>()` bytes. Each byte
/// of a chunk is compared against every letter of `A::as_str()`, and the index
/// of the matching letter is written to the same position of `dst`. Whatever
/// remains after the last full chunk is handed to the generic pipeline.
///
/// # Errors
///
/// If a vectorized chunk contained a byte matching no letter of the alphabet,
/// the whole sequence is re-scanned with `A::Symbol::from_ascii` and the first
/// error it reports is returned. Errors raised by the generic pipeline while
/// encoding the tail are propagated as well.
///
/// # Panics
///
/// Panics if `seq` and `dst` do not have the same length.
///
/// # Safety
///
/// This function is compiled with the `sse2` target feature enabled; the
/// caller must make sure the running CPU supports SSE2.
///
/// NOTE(review): each iteration stores one 16-byte vector while advancing
/// `dst` by 16 elements, which presumes `A::Symbol` is one byte wide and that
/// its byte value is its index in `A::as_str()` -- confirm against `Symbol`.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse2")]
// The `0xFF` byte mask below is intentionally passed where an `i8` is expected.
#[allow(overflowing_literals)]
unsafe fn encode_into_sse2<A>(seq: &[u8], dst: &mut [A::Symbol]) -> Result<(), InvalidSymbol>
where
    A: Alphabet,
{
    const STRIDE: usize = std::mem::size_of::<__m128i>();
    let alphabet = A::as_str().as_bytes();
    let g = Pipeline::<A, _>::generic();
    let l = seq.len();
    assert_eq!(seq.len(), dst.len());
    unsafe {
        // Use raw pointers since we cannot be sure `seq` and `dst` are aligned.
        let mut i = 0;
        let mut src_ptr = seq.as_ptr();
        let mut dst_ptr = dst.as_mut_ptr();
        // Accumulates a non-zero byte for every lane that matched no letter.
        let mut error = _mm_setzero_si128();
        // Process the sequence in SIMD while a full vector is available. The
        // bound is inclusive so that a sequence whose length is an exact
        // multiple of `STRIDE` is encoded entirely here, as in the AVX2 code.
        while i + STRIDE <= l {
            // Load the current chunk and reset the buffers for the result.
            let letters = _mm_loadu_si128(src_ptr as *const __m128i);
            let mut encoded = _mm_set1_epi8((A::K::USIZE - 1) as i8);
            let mut unknown = _mm_set1_epi8(0xFF);
            // Check symbols one by one and match them to the letters.
            for a in 0..A::K::USIZE {
                let index = _mm_set1_epi8(a as i8);
                let ascii = _mm_set1_epi8(alphabet[a] as i8);
                let m = _mm_cmpeq_epi8(letters, ascii);
                // Select `index` in matching lanes, keep `encoded` elsewhere.
                encoded = _mm_or_si128(_mm_andnot_si128(m, encoded), _mm_and_si128(m, index));
                // Clear the "unknown" marker of every lane that matched.
                unknown = _mm_andnot_si128(m, unknown);
            }
            // Record if some symbols of the current vector are unknown.
            error = _mm_or_si128(error, unknown);
            // Store the encoded result to the output buffer.
            _mm_storeu_si128(dst_ptr as *mut __m128i, encoded);
            // Advance to the next addresses in input and output.
            src_ptr = src_ptr.add(STRIDE);
            dst_ptr = dst_ptr.add(STRIDE);
            i += STRIDE;
        }
        // If an invalid symbol was encountered, recover which one.
        // FIXME: vectorize the error search?
        let mut x: [u8; 16] = [0; 16];
        _mm_storeu_si128(x.as_mut_ptr() as *mut __m128i, error);
        if x.iter().any(|&x| x != 0) {
            for &c in seq {
                let _ = A::Symbol::from_ascii(c)?;
            }
        }
        // Encode the rest of the sequence using the generic implementation.
        if i < l {
            g.encode_into(&seq[i..], &mut dst[i..])?;
        }
    }
    Ok(())
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse2")]
unsafe fn score_sse2<A: Alphabet, C: MultipleOf<<Sse2 as Backend>::LANES>>(
......@@ -184,6 +255,19 @@ unsafe fn argmax_sse2<C: MultipleOf<<Sse2 as Backend>::LANES>>(
}
impl Sse2 {
/// Encode the ASCII bytes of `seq` into `dst` with the SSE2 implementation.
///
/// On `x86` and `x86_64` targets this forwards to `encode_into_sse2` and
/// returns its result unchanged.
///
/// # Panics
///
/// On every other target architecture this function always panics, since no
/// SSE2 code is compiled there.
#[allow(unused)]
pub fn encode_into<A>(seq: &[u8], dst: &mut [A::Symbol]) -> Result<(), InvalidSymbol>
where
A: Alphabet,
{
// NOTE(review): `encode_into_sse2` is built with the `sse2` target feature
// and nothing here checks for it at run time; presumably an SSE2 pipeline
// is only handed out after feature detection -- confirm with the caller.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
unsafe {
encode_into_sse2::<A>(seq, dst)
}
// Only one of the two `cfg` branches survives compilation.
#[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
panic!("attempting to run SSE2 code on a non-x86 host");
}
#[allow(unused)]
pub fn score_rows_into<A, C, S, M>(
pssm: M,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment