From 41a6d1f729a217e05913b0de52fce605f30bda99 Mon Sep 17 00:00:00 2001
From: Martin Larralde <martin.larralde@embl.de>
Date: Fri, 21 Jun 2024 16:52:48 +0200
Subject: [PATCH] Add tests to ensure `Maximum::argmax` works as expected for all implementations

---
Reviewer notes (outside the commit message and outside the diff):

* Line breaks and indentation were restored from a whitespace-collapsed
  copy of this patch. Line counts match the hunk headers (6/8, 7/7 and
  0/142). NOTE(review): indentation follows rustfmt defaults because it
  could not be recovered from the collapsed text; confirm the context
  lines against the repository before applying.
* lightmotif/src/pli/mod.rs: the `Maximum<f32, C>` impl for
  `Pipeline<A, Sse2>` gains an `ArrayLength` bound on `C`, together with
  the `generic_array::ArrayLength` import it needs.
* lightmotif/tests/argmax.rs: `test_argmax_f32` scores the first tenth of
  benches/ecoli.txt with the generic pipeline, finds the best entry with a
  scalar `max_by` scan over the unstriped scores, and asserts that
  `argmax` of the pipeline under test maps to the same offset and value.
  `dispatch::scanner_max` makes the same comparison against the result of
  `Scanner::new(..).max()`.
* NOTE(review): `Iterator::max_by` keeps the last of several equal maxima,
  and `partial_cmp(..).unwrap()` panics if a score is NaN. The position
  assertions therefore presume a unique, NaN-free maximum, or tie-breaking
  in `argmax` / `Scanner::max` that agrees with `max_by`; confirm.
* The `sse2`, `avx2` and `neon` test modules are only compiled when the
  matching `target_feature` cfg is enabled for the build.

 lightmotif/src/pli/mod.rs  |   4 +-
 lightmotif/tests/argmax.rs | 142 +++++++++++++++++++++++++++++++++++++
 2 files changed, 145 insertions(+), 1 deletion(-)
 create mode 100644 lightmotif/tests/argmax.rs

diff --git a/lightmotif/src/pli/mod.rs b/lightmotif/src/pli/mod.rs
index 1a5f57f..80a053b 100644
--- a/lightmotif/src/pli/mod.rs
+++ b/lightmotif/src/pli/mod.rs
@@ -3,6 +3,8 @@
 use std::ops::AddAssign;
 use std::ops::Range;
 
+use generic_array::ArrayLength;
+
 use crate::abc::Alphabet;
 use crate::abc::Dna;
 use crate::abc::Protein;
@@ -361,7 +363,7 @@ where
 impl<A, C> Maximum<f32, C> for Pipeline<A, Sse2>
 where
     A: Alphabet,
-    C: StrictlyPositive + MultipleOf<U16>,
+    C: StrictlyPositive + MultipleOf<U16> + ArrayLength,
 {
     #[inline]
     fn argmax(&self, scores: &StripedScores<f32, C>) -> Option<MatrixCoordinates> {
diff --git a/lightmotif/tests/argmax.rs b/lightmotif/tests/argmax.rs
new file mode 100644
index 0000000..fdab04e
--- /dev/null
+++ b/lightmotif/tests/argmax.rs
@@ -0,0 +1,142 @@
+extern crate lightmotif;
+
+use lightmotif::abc::Background;
+use lightmotif::abc::Dna;
+use lightmotif::num::StrictlyPositive;
+use lightmotif::num::U1;
+use lightmotif::num::U16;
+use lightmotif::num::U32;
+use lightmotif::pli::Maximum;
+use lightmotif::pli::Pipeline;
+use lightmotif::pli::Score;
+use lightmotif::pli::Stripe;
+use lightmotif::pwm::CountMatrix;
+use lightmotif::scan::Scanner;
+use lightmotif::scores::StripedScores;
+use lightmotif::seq::EncodedSequence;
+
+const SEQUENCE: &str = include_str!("../benches/ecoli.txt");
+const PATTERNS: &[&str] = &["GTTGACCTTATCAAC", "GTTGATCCAGTCAAC"];
+const N: usize = SEQUENCE.len() / 10;
+
+fn test_argmax_f32<C: StrictlyPositive, P: Maximum<f32, C>>(pli: &P) {
+    let generic = Pipeline::generic();
+
+    let seq = &SEQUENCE[..N];
+    let encoded = EncodedSequence::<Dna>::encode(seq).unwrap();
+    let mut striped = generic.stripe(encoded);
+
+    let bg = Background::<Dna>::uniform();
+    let cm = PATTERNS
+        .iter()
+        .map(EncodedSequence::encode)
+        .map(Result::unwrap)
+        .collect::<Result<CountMatrix<Dna>, _>>()
+        .unwrap();
+    let pbm = cm.to_freq(0.1);
+    let pssm = pbm.to_scoring(bg);
+
+    striped.configure(&pssm);
+    let scores: StripedScores<f32, C> = generic.score(pssm, striped);
+
+    let best = scores
+        .unstripe()
+        .iter()
+        .cloned()
+        .enumerate()
+        .max_by(|x, y| x.1.partial_cmp(&y.1).unwrap())
+        .unwrap();
+    let m = pli.argmax(&scores).unwrap();
+    assert_eq!(scores.offset(m), best.0);
+    assert_eq!(scores.matrix()[m], best.1);
+}
+
+mod generic {
+    use super::*;
+
+    #[test]
+    fn argmax_f32() {
+        let pli = Pipeline::<Dna, _>::generic();
+        super::test_argmax_f32::<U1, _>(&pli);
+        super::test_argmax_f32::<U16, _>(&pli);
+        super::test_argmax_f32::<U32, _>(&pli);
+    }
+}
+
+mod dispatch {
+    use super::*;
+
+    #[test]
+    fn argmax_f32() {
+        let pli = Pipeline::<Dna, _>::dispatch();
+        super::test_argmax_f32(&pli);
+    }
+
+    #[test]
+    fn scanner_max() {
+        let generic = Pipeline::generic();
+
+        let seq = &SEQUENCE[..N];
+        let encoded = EncodedSequence::<Dna>::encode(seq).unwrap();
+        let mut striped = generic.stripe(encoded);
+
+        let bg = Background::<Dna>::uniform();
+        let cm = PATTERNS
+            .iter()
+            .map(EncodedSequence::encode)
+            .map(Result::unwrap)
+            .collect::<Result<CountMatrix<Dna>, _>>()
+            .unwrap();
+        let pbm = cm.to_freq(0.1);
+        let pssm = pbm.to_scoring(bg);
+
+        striped.configure(&pssm);
+        let scores: StripedScores<f32, _> = generic.score(&pssm, &striped);
+        let best = scores
+            .unstripe()
+            .iter()
+            .cloned()
+            .enumerate()
+            .max_by(|x, y| x.1.partial_cmp(&y.1).unwrap())
+            .unwrap();
+
+        let m = Scanner::new(&pssm, &striped).max().unwrap();
+        assert_eq!(m.position, best.0);
+        assert_eq!(m.score, best.1);
+    }
+}
+
+#[cfg(target_feature = "sse2")]
+mod sse2 {
+    use super::*;
+
+    #[test]
+    fn argmax_f32() {
+        let pli = Pipeline::<Dna, _>::sse2().unwrap();
+        super::test_argmax_f32::<U16, _>(&pli);
+        super::test_argmax_f32::<U32, _>(&pli);
+    }
+}
+
+#[cfg(target_feature = "avx2")]
+mod avx2 {
+    use super::*;
+
+    #[test]
+    fn argmax_f32() {
+        let pli = Pipeline::<Dna, _>::avx2().unwrap();
+        super::test_argmax_f32(&pli);
+    }
+}
+
+#[cfg(target_feature = "neon")]
+mod neon {
+    use super::*;
+
+    #[test]
+    fn argmax_f32() {
+        let pli = Pipeline::<Dna, _>::neon().unwrap();
+        super::test_argmax_f32::<U16, _>(&pli);
+        super::test_argmax_f32::<U32, _>(&pli);
+    }
+}
-- 
GitLab