Skip to content
Snippets Groups Projects
Commit eba9e7d1 authored by Martin Larralde's avatar Martin Larralde
Browse files

Implement `Score<u8, Dna>` for the NEON platform

parent 0f4807d2
No related branches found
No related tags found
No related merge requests found
......@@ -148,8 +148,8 @@ impl Score<u8, Dna, <Dispatch as Backend>::LANES> for Pipeline<Dna, Dispatch> {
Dispatch::Avx2 => Avx2::score_u8_rows_into_shuffle(pssm, seq.as_ref(), rows, scores),
// #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
// Dispatch::Sse2 => Sse2::score_rows_into(pssm, seq.as_ref(), rows, scores),
// #[cfg(any(target_arch = "arm", target_arch = "aarch64"))]
// Dispatch::Neon => Neon::score_rows_into(pssm, seq.as_ref(), rows, scores),
#[cfg(any(target_arch = "arm", target_arch = "aarch64"))]
Dispatch::Neon => Neon::score_u8_rows_into(pssm, seq.as_ref(), rows, scores),
_ => <Generic as Score<u8, Dna, <Dispatch as Backend>::LANES>>::score_rows_into(
&Generic,
pssm,
......@@ -198,6 +198,7 @@ impl<A: Alphabet> Maximum<u8, <Dispatch as Backend>::LANES> for Pipeline<A, Disp
scores: &StripedScores<u8, <Dispatch as Backend>::LANES>,
) -> Option<MatrixCoordinates> {
match self.backend {
// FIXME !!!!
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
Dispatch::Avx2 => Avx2::argmax_u8(scores),
// #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
......
......@@ -4,6 +4,8 @@ use std::ops::Add;
use std::ops::AddAssign;
use std::ops::Range;
use typenum::IsLessOrEqual;
use crate::abc::Alphabet;
use crate::abc::Dna;
use crate::abc::Protein;
......@@ -16,7 +18,6 @@ use crate::err::UnsupportedBackend;
use crate::num::MultipleOf;
use crate::num::StrictlyPositive;
use crate::num::U16;
use crate::pwm::ScoringMatrix;
use crate::scores::StripedScores;
use crate::seq::EncodedSequence;
use crate::seq::StripedSequence;
......@@ -523,13 +524,38 @@ where
A: Alphabet,
C: StrictlyPositive + MultipleOf<U16>,
{
#[inline]
fn score_rows_into<S, M>(
&self,
pssm: M,
seq: S,
rows: Range<usize>,
scores: &mut StripedScores<f32, C>,
) where
S: AsRef<StripedSequence<A, C>>,
M: AsRef<DenseMatrix<f32, <A as Alphabet>::K>>,
{
Neon::score_f32_rows_into(pssm, seq, rows, scores);
}
}
impl<A, C> Score<u8, A, C> for Pipeline<A, Neon>
impl<C> Score<u8, Dna, C> for Pipeline<Dna, Neon>
where
A: Alphabet,
C: StrictlyPositive + MultipleOf<U16>,
{
#[inline]
fn score_rows_into<S, M>(
&self,
pssm: M,
seq: S,
rows: Range<usize>,
scores: &mut StripedScores<u8, C>,
) where
S: AsRef<StripedSequence<Dna, C>>,
M: AsRef<DenseMatrix<u8, <Dna as Alphabet>::K>>,
{
Neon::score_u8_rows_into(pssm, seq, rows, scores);
}
}
impl<A, C> Maximum<f32, C> for Pipeline<A, Neon>
......
......@@ -301,9 +301,9 @@ pub unsafe fn score_u8_avx2_shuffle<A>(
// because in AVX2 shuffle operates on the two halves independently.
let t = _mm256_broadcastsi128_si256(_mm_load_si128(&*(pssmptr as *const __m128i)));
// load scores for given sequence
let x = _mm256_shuffle_epi8(t, x);
let y = _mm256_shuffle_epi8(t, x);
// add scores to the running sum
s = _mm256_adds_epu8(s, x);
s = _mm256_adds_epu8(s, y);
// advance to next row in PSSM and sequence matrices
seqptr = seqptr.add(seq.matrix().stride());
pssmptr = pssmptr.add(pssm.stride());
......
......@@ -14,11 +14,13 @@ use crate::abc::Alphabet;
use crate::abc::Symbol;
use crate::dense::DenseMatrix;
use crate::err::InvalidSymbol;
use crate::num::consts::U16;
use crate::num::IsLessOrEqual;
use crate::num::MultipleOf;
use crate::num::NonZero;
use crate::num::StrictlyPositive;
use crate::num::Unsigned;
use crate::num::Zero;
use crate::num::U16;
use crate::pli::Encode;
use crate::pli::Pipeline;
use crate::pwm::ScoringMatrix;
......@@ -180,6 +182,50 @@ unsafe fn score_f32_neon<A: Alphabet, C: MultipleOf<U16>>(
}
}
#[cfg(any(target_arch = "arm", target_arch = "aarch64"))]
#[target_feature(enable = "neon")]
unsafe fn score_u8_neon<A: Alphabet, C: MultipleOf<U16>>(
pssm: &DenseMatrix<u8, A::K>,
seq: &StripedSequence<A, C>,
rows: Range<usize>,
scores: &mut StripedScores<u8, C>,
) {
use crate::dense::DenseMatrix;
let zero_u8 = vdupq_n_u8(0);
let zero_f32 = vdupq_n_f32(0.0);
// process columns of the striped matrix, any multiple of 16 is supported
let data = scores.matrix_mut();
for offset in (0..C::Quotient::USIZE).map(|i| i * <Neon as Backend>::LANES::USIZE) {
let mut rowptr = data[0].as_mut_ptr().add(offset);
// process every position of the sequence data
for i in rows.clone() {
// reset sums for current position
let mut s = vdupq_n_u8(0);
// reset position
let mut seqptr = seq.matrix()[i].as_ptr().add(offset);
let mut pssmptr = pssm[0].as_ptr();
// advance position in the position weight matrix
for _ in 0..pssm.rows() {
// load sequence row
let x = vld1q_u8(seqptr as *const u8);
// load pssm row
let t = vld1q_u8(pssmptr as *const u8);
// shuffle pssm with the sequence characters
let y = vqtbl1q_u8(t, x);
// add scores to the running sum
s = vaddq_u8(s, y);
// advance to next row in PSSM and sequence matrices
seqptr = seqptr.add(seq.matrix().stride());
pssmptr = pssmptr.add(pssm.stride());
}
// record the score for the current position
vst1q_u8(rowptr, s);
rowptr = rowptr.add(data.stride());
}
}
}
impl Neon {
#[allow(unused)]
pub fn encode_into<A>(seq: &[u8], dst: &mut [A::Symbol]) -> Result<(), InvalidSymbol>
......@@ -232,4 +278,42 @@ impl Neon {
#[cfg(not(any(target_arch = "arm", target_arch = "aarch64")))]
panic!("attempting to run NEON code on a non-Arm host")
}
#[allow(unused)]
pub fn score_u8_rows_into<A, C, S, M>(
pssm: M,
seq: S,
rows: Range<usize>,
scores: &mut StripedScores<u8, C>,
) where
<A as Alphabet>::K: IsLessOrEqual<U16>,
<<A as Alphabet>::K as IsLessOrEqual<U16>>::Output: NonZero,
A: Alphabet,
C: MultipleOf<U16>,
S: AsRef<StripedSequence<A, C>>,
M: AsRef<DenseMatrix<u8, A::K>>,
{
let seq = seq.as_ref();
let pssm = pssm.as_ref();
if seq.wrap() < pssm.rows() - 1 {
panic!(
"not enough wrapping rows for motif of length {}",
pssm.rows()
);
}
if seq.len() < pssm.rows() || rows.len() == 0 {
scores.resize(0, 0);
return;
}
scores.resize(rows.len(), (seq.len() + 1).saturating_sub(pssm.rows()));
#[cfg(any(target_arch = "arm", target_arch = "aarch64"))]
unsafe {
score_u8_neon(pssm, seq, rows, scores);
}
#[cfg(not(any(target_arch = "arm", target_arch = "aarch64")))]
panic!("attempting to run NEON code on a non-Arm host")
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment