diff --git a/lightmotif-bench/dna.rs b/lightmotif-bench/dna.rs index e2c96ae21385ac6a175c9f4fe4573bbab210472c..906f5b2b99f039b47a7e7927c658a5b4e7d1b898 100644 --- a/lightmotif-bench/dna.rs +++ b/lightmotif-bench/dna.rs @@ -67,7 +67,7 @@ fn bench_scanner_best(bencher: &mut test::Bencher) { striped.configure(&pssm); let mut best = 0; - bencher.iter(|| best = Scanner::new(&pssm, &striped).best().unwrap().position); + bencher.iter(|| best = Scanner::new(&pssm, &striped).max().unwrap().position); bencher.bytes = seq.len() as u64; println!("best: {:?}", best); diff --git a/lightmotif/benches/score.rs b/lightmotif/benches/score.rs index 44ee5df5d76318a93d05729462b15eb02bdff1cc..a1fee2613019ea1cfa49734104493f4ff83a8b86 100644 --- a/lightmotif/benches/score.rs +++ b/lightmotif/benches/score.rs @@ -10,6 +10,7 @@ use lightmotif::pli::Pipeline; use lightmotif::pli::Score; use lightmotif::pli::Stripe; use lightmotif::pwm::CountMatrix; +use lightmotif::scan::DefaultColumns; use lightmotif::scores::StripedScores; use lightmotif::seq::EncodedSequence; diff --git a/lightmotif/benches/stripe.rs b/lightmotif/benches/stripe.rs index 65d90365ea98feac4f98032f23178ce7e4e326c0..61c055c28f051886de0ff1bd682a4db1cc6e0697 100644 --- a/lightmotif/benches/stripe.rs +++ b/lightmotif/benches/stripe.rs @@ -6,6 +6,7 @@ extern crate test; use lightmotif::num::U32; use lightmotif::pli::Pipeline; use lightmotif::pli::Stripe; +use lightmotif::scan::DefaultColumns; use lightmotif::seq::EncodedSequence; mod dna { @@ -15,7 +16,7 @@ mod dna { const SEQUENCE: &str = include_str!("ecoli.txt"); - fn bench<P: Stripe<Dna, U32>>(bencher: &mut test::Bencher, pli: &P) { + fn bench<P: Stripe<Dna, DefaultColumns>>(bencher: &mut test::Bencher, pli: &P) { let seq = EncodedSequence::encode(SEQUENCE).unwrap(); let mut dst = seq.to_striped(); diff --git a/lightmotif/src/pli/mod.rs b/lightmotif/src/pli/mod.rs index e53a32d8fafb33acc5f649f6db28d30da9059d2d..25a4513168711b22dda25e6c69c7e024070a2592 100644 --- a/lightmotif/src/pli/mod.rs +++ b/lightmotif/src/pli/mod.rs @@ -525,6 +525,13 @@ where { } +impl<A, C> Score<u8, A, C> for Pipeline<A, Neon> +where + A: Alphabet, + C: StrictlyPositive + MultipleOf<U16>, +{ +} + impl<A, C> Maximum<f32, C> for Pipeline<A, Neon> where A: Alphabet, @@ -532,6 +539,13 @@ where { } +impl<A, C> Maximum<u8, C> for Pipeline<A, Neon> +where + A: Alphabet, + C: StrictlyPositive + MultipleOf<U16>, +{ +} + impl<A, C> Threshold<f32, C> for Pipeline<A, Neon> where A: Alphabet, @@ -539,6 +553,13 @@ where { } +impl<A, C> Threshold<u8, C> for Pipeline<A, Neon> +where + A: Alphabet, + C: StrictlyPositive + MultipleOf<U16>, +{ +} + // -- Tests -------------------------------------------------------------------- #[cfg(test)] diff --git a/lightmotif/src/pli/platform/neon.rs b/lightmotif/src/pli/platform/neon.rs index b03874debacafe3e459ebdfd9ca8d8341b996884..2c636356866233249859bc62675cfe9db2c1f622 100644 --- a/lightmotif/src/pli/platform/neon.rs +++ b/lightmotif/src/pli/platform/neon.rs @@ -136,6 +136,7 @@ unsafe fn score_f32_neon<A: Alphabet, C: MultipleOf<U16>>( // process columns of the striped matrix, any multiple of 16 is supported let data = scores.matrix_mut(); for offset in (0..C::Quotient::USIZE).map(|i| i * <Neon as Backend>::LANES::USIZE) { + let mut rowptr = data[0].as_mut_ptr().add(offset); // process every position of the sequence data for i in rows.clone() { // reset sums for current position @@ -145,9 +146,10 @@ unsafe fn score_f32_neon<A: Alphabet, C: MultipleOf<U16>>( let mut pssmptr = pssm[0].as_ptr(); // advance position in the position weight matrix for _ in 0..pssm.rows() { - // load sequence row and broadcast to f32 + // load sequence row let x = vld1q_u8(dataptr as *const u8); let z = vzipq_u8(x, zero_u8); + // transform u8 into u32 let lo = vzipq_u8(z.0, zero_u8); let hi = vzipq_u8(z.1, zero_u8); let x1 = vreinterpretq_u32_u8(lo.0); @@ -172,8 +174,8 @@ unsafe fn score_f32_neon<A: Alphabet, C: MultipleOf<U16>>( pssmptr = pssmptr.add(pssm.stride()); } // record the score for the current position - let row = &mut data[i]; - vst1q_f32_x4(row[offset..].as_mut_ptr(), s); + vst1q_f32_x4(rowptr, s); + rowptr = rowptr.add(data.stride()); } } } diff --git a/lightmotif/src/scan.rs b/lightmotif/src/scan.rs index 498566758c9e813d0f9b52d63a5b0a796a80371c..6a5428ba6fa0cc0180d815133b6545285212baca 100644 --- a/lightmotif/src/scan.rs +++ b/lightmotif/src/scan.rs @@ -1,16 +1,37 @@ //! Scanner implementation using a fixed block size for the scores. +use std::cmp::Ordering; + use super::abc::Alphabet; use super::pli::dispatch::Dispatch; +use super::pli::platform::Avx2; use super::pli::platform::Backend; +use super::pli::platform::Neon; use super::pwm::ScoringMatrix; use super::seq::StripedSequence; +use crate::dense::DenseMatrix; +use crate::num::StrictlyPositive; +use crate::num::U32; use crate::pli::Maximum; use crate::pli::Pipeline; use crate::pli::Score; use crate::pli::Threshold; +use crate::pwm::DiscreteMatrix; use crate::scores::StripedScores; -type C = <Dispatch as Backend>::LANES; +#[cfg(target_arch = "x86_64")] +type _DefaultColumns = typenum::consts::U32; +#[cfg(any(target_arch = "x86", target_arch = "arm", target_arch = "aarch64"))] +type _DefaultColumns = typenum::consts::U16; +#[cfg(not(any( + target_arch = "x86", + target_arch = "x86_64", + target_arch = "arm", + target_arch = "aarch64" +)))] +type _DefaultColumns = typenum::consts::U1; + +/// The default column number used in scanners. +pub type DefaultColumns = _DefaultColumns; #[derive(Debug)] enum CowMut<'a, T> { @@ -44,7 +65,7 @@ impl<T: Default> Default for CowMut<'_, T> { } /// A hit describing a scored position somewhere in the sequence. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq)] pub struct Hit { pub position: usize, pub score: f32, @@ -53,15 +74,35 @@ pub struct Hit { impl Hit { /// Create a new hit. pub fn new(position: usize, score: f32) -> Self { + assert!(!score.is_nan()); Self { position, score } } } +impl Eq for Hit {} + +impl PartialOrd for Hit { + fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> { + match self.score.partial_cmp(&other.score)? { + Ordering::Equal => self.position.partial_cmp(&other.position), + other => Some(other), + } + } +} + +impl Ord for Hit { + fn cmp(&self, other: &Self) -> Ordering { + self.partial_cmp(other).unwrap() + } +} + #[derive(Debug)] -pub struct Scanner<'a, A: Alphabet> { +pub struct Scanner<'a, A: Alphabet, C: StrictlyPositive = DefaultColumns> { pssm: &'a ScoringMatrix<A>, + dm: DiscreteMatrix<A>, seq: &'a StripedSequence<A, C>, scores: CowMut<'a, StripedScores<f32, C>>, + dscores: StripedScores<u8, C>, threshold: f32, block_size: usize, row: usize, @@ -69,15 +110,17 @@ pub struct Scanner<'a, A: Alphabet> { pipeline: Pipeline<A, Dispatch>, } -impl<'a, A: Alphabet> Scanner<'a, A> { +impl<'a, A: Alphabet, C: StrictlyPositive> Scanner<'a, A, C> { /// Create a new scanner for the given matrix and sequence. pub fn new(pssm: &'a ScoringMatrix<A>, seq: &'a StripedSequence<A, C>) -> Self { Self { pssm, seq, + dm: pssm.to_discrete(), scores: CowMut::Owned(StripedScores::empty()), + dscores: StripedScores::empty(), threshold: 0.0, - block_size: 512, + block_size: 256, row: 0, hits: Vec::new(), pipeline: Pipeline::dispatch(), @@ -103,56 +146,60 @@ impl<'a, A: Alphabet> Scanner<'a, A> { } } -impl<'a, A: Alphabet> Scanner<'a, A> +impl<'a, A, C> Iterator for Scanner<'a, A, C> where - Pipeline<A, Dispatch>: Score<f32, A, C> + Maximum<f32, C>, -{ - /// Consume the scanner to find the best hit. - pub fn best(&mut self) -> Option<Hit> { - let pli = Pipeline::dispatch(); - let mut best = std::mem::take(&mut self.hits) - .into_iter() - .max_by(|x, y| x.score.partial_cmp(&y.score).unwrap()); - while self.row < self.seq.matrix().rows() { - let end = (self.row + self.block_size).min(self.seq.matrix().rows() - self.seq.wrap()); - pli.score_rows_into(&self.pssm, &self.seq, self.row..end, &mut self.scores); - let matrix = self.scores.matrix(); - if let Some(c) = pli.argmax(&self.scores) { - let score = matrix[c]; - if best - .as_ref() - .map(|hit: &Hit| matrix[c] >= hit.score) - .unwrap_or(true) - { - let index = - c.col * (self.seq.matrix().rows() - self.seq.wrap()) + self.row + c.row; - best = Some(Hit::new(index, score)); - } - } - self.row += self.block_size; - } - best - } -} - -impl<'a, A: Alphabet> Iterator for Scanner<'a, A> -where - Pipeline<A, Dispatch>: Score<f32, A, C> + Threshold<f32, C>, + A: Alphabet, + C: StrictlyPositive, + Pipeline<A, Dispatch>: Score<u8, A, C> + Threshold<u8, C> + Maximum<u8, C>, { type Item = Hit; fn next(&mut self) -> Option<Self::Item> { + let t = self.dm.scale(self.threshold); while self.hits.is_empty() && self.row < self.seq.matrix().rows() { + // compute the row slice to score in the striped sequence matrix let end = (self.row + self.block_size) .min(self.seq.matrix().rows().saturating_sub(self.seq.wrap())); + // score the row slice self.pipeline - .score_rows_into(&self.pssm, &self.seq, self.row..end, &mut self.scores); - let matrix = self.scores.matrix(); - for c in self.pipeline.threshold(&self.scores, self.threshold) { + .score_rows_into(&self.dm, &self.seq, self.row..end, &mut self.dscores); + // scan through the positions above discrete threshold and recompute + // scores in floating-point to see if they pass the real threshold. + for c in self.pipeline.threshold(&self.dscores, t) { let index = c.col * (self.seq.matrix().rows() - self.seq.wrap()) + self.row + c.row; - self.hits.push(Hit::new(index, matrix[c])); + let score = self.pssm.score_position(&self.seq, index); + if score >= self.threshold { + self.hits.push(Hit::new(index, score)); + } } self.row += self.block_size; } self.hits.pop() } + + fn max(mut self) -> Option<Self::Item> { + let mut best = std::mem::take(&mut self.hits) + .into_iter() + .max_by(|x, y| x.score.partial_cmp(&y.score).unwrap()) + .unwrap_or(Hit::new(0, f32::MIN)); + let mut best_discrete = self.dm.scale(best.score); + while self.row < self.seq.matrix().rows() { + let end = (self.row + self.block_size).min(self.seq.matrix().rows() - self.seq.wrap()); + self.pipeline + .score_rows_into(&self.dm, &self.seq, self.row..end, &mut self.dscores); + if let Some(c) = self.pipeline.argmax(&self.dscores) { + let dscore = self.dscores.matrix()[c]; + if dscore >= best_discrete { + let index = + c.col * (self.seq.matrix().rows() - self.seq.wrap()) + self.row + c.row; + let score = self.pssm.score_position(&self.seq, index); + if (score > best.score) | (score == best.score && index > best.position) { + best = Hit::new(index, score); + best_discrete = dscore; + } + } + } + self.row += self.block_size; + } + Some(best) + } } diff --git a/lightmotif/tests/dna.rs b/lightmotif/tests/dna.rs index 989dec11e4b00fc5e2fd987a807d1c59d329b5bc..c5fb96d4a152bbe6518d5526b2e1f06bb573aa6d 100644 --- a/lightmotif/tests/dna.rs +++ b/lightmotif/tests/dna.rs @@ -232,25 +232,25 @@ mod dispatch { #[test] fn test_score() { let pli = Pipeline::dispatch(); - super::test_score::<U32, _>(&pli); + super::test_score(&pli); } #[test] fn test_score_rows() { let pli = Pipeline::dispatch(); - super::test_score_rows::<U32, _>(&pli); + super::test_score_rows(&pli); } #[test] fn test_argmax() { let pli = Pipeline::dispatch(); - super::test_argmax::<U32, _>(&pli); + super::test_argmax(&pli); } #[test] fn test_threshold() { let pli = Pipeline::dispatch(); - super::test_threshold::<U32, _>(&pli); + super::test_threshold(&pli); } }