From cd75a9053ba41fbeb953df03944effd309fabf7f Mon Sep 17 00:00:00 2001 From: Martin Larralde <martin.larralde@embl.de> Date: Fri, 21 Jun 2024 00:55:51 +0200 Subject: [PATCH] Add unit tests for `lightmotif::scan::Scanner` --- lightmotif/src/scan.rs | 67 +++++++++++++++++++++++++++++++++++++++++- 1 file changed, 66 insertions(+), 1 deletion(-) diff --git a/lightmotif/src/scan.rs b/lightmotif/src/scan.rs index 6a5428b..07ce441 100644 --- a/lightmotif/src/scan.rs +++ b/lightmotif/src/scan.rs @@ -181,7 +181,7 @@ where .into_iter() .max_by(|x, y| x.score.partial_cmp(&y.score).unwrap()) .unwrap_or(Hit::new(0, f32::MIN)); - let mut best_discrete = self.dm.scale(best.score); + let mut best_discrete = self.dm.scale(best.score.max(self.threshold)); while self.row < self.seq.matrix().rows() { let end = (self.row + self.block_size).min(self.seq.matrix().rows() - self.seq.wrap()); self.pipeline @@ -203,3 +203,68 @@ where Some(best) } } + +#[cfg(test)] +mod test { + use super::*; + use crate::abc::Dna; + use crate::pli::Stripe; + use crate::pwm::CountMatrix; + use crate::pwm::ScoringMatrix; + use crate::seq::EncodedSequence; + + const SEQUENCE: &str = "ATGTCCCAACAACGATACCCCGAGCCCATCGCCGTCATCGGCTCGGCATGCAGATTCCCAGGCG"; + const PATTERNS: &[&str] = &["GTTGACCTTATCAAC", "GTTGATCCAGTCAAC"]; + + fn seq<C: StrictlyPositive>() -> StripedSequence<Dna, C> { + let encoded = EncodedSequence::<Dna>::encode(SEQUENCE).unwrap(); + Pipeline::generic().stripe(encoded) + } + + fn pssm() -> ScoringMatrix<Dna> { + let cm = CountMatrix::<Dna>::from_sequences( + PATTERNS.iter().map(|x| EncodedSequence::encode(x).unwrap()), + ) + .unwrap(); + let pbm = cm.to_freq(0.1); + let pwm = pbm.to_weight(None); + pwm.to_scoring() + } + + #[test] + fn test_collect() { + let pssm = self::pssm(); + let mut striped = self::seq(); + striped.configure(&pssm); + + let mut scanner = Scanner::new(&pssm, &striped); + scanner.threshold(-10.0); + + let mut hits = scanner.collect::<Vec<_>>(); + assert_eq!(hits.len(), 3); + + hits.sort_by_key(|hit| hit.position); + assert_eq!(hits[0].position, 18); + assert_eq!(hits[1].position, 27); + assert_eq!(hits[2].position, 32); + } + + #[test] + fn test_max() { + let pssm = self::pssm(); + let mut striped = self::seq(); + striped.configure(&pssm); + + let mut scanner = Scanner::new(&pssm, &striped); + scanner.threshold(-10.0); + + let hit = scanner.max().unwrap(); + assert!( + (hit.score - -5.50167).abs() < 1e-5, + "{} != {}", + hit.score, + -5.50167 + ); + assert_eq!(hit.position, 18); + } +} -- GitLab