From cd75a9053ba41fbeb953df03944effd309fabf7f Mon Sep 17 00:00:00 2001
From: Martin Larralde <martin.larralde@embl.de>
Date: Fri, 21 Jun 2024 00:55:51 +0200
Subject: [PATCH] Add unit tests for `lightmotif::scan::Scanner`

---
 lightmotif/src/scan.rs | 67 +++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 66 insertions(+), 1 deletion(-)

diff --git a/lightmotif/src/scan.rs b/lightmotif/src/scan.rs
index 6a5428b..07ce441 100644
--- a/lightmotif/src/scan.rs
+++ b/lightmotif/src/scan.rs
@@ -181,7 +181,7 @@ where
             .into_iter()
             .max_by(|x, y| x.score.partial_cmp(&y.score).unwrap())
             .unwrap_or(Hit::new(0, f32::MIN));
-        let mut best_discrete = self.dm.scale(best.score);
+        let mut best_discrete = self.dm.scale(best.score.max(self.threshold));
         while self.row < self.seq.matrix().rows() {
             let end = (self.row + self.block_size).min(self.seq.matrix().rows() - self.seq.wrap());
             self.pipeline
@@ -203,3 +203,68 @@ where
         Some(best)
     }
 }
+
+#[cfg(test)] // built only for test runs
+mod test { // unit tests for the scanner's `collect` and `max` paths
+    use super::*;
+    use crate::abc::Dna;
+    use crate::pli::Stripe;
+    use crate::pwm::CountMatrix;
+    use crate::pwm::ScoringMatrix;
+    use crate::seq::EncodedSequence;
+    // Shared fixtures: the DNA text to scan and the sequences the matrix is built from.
+    const SEQUENCE: &str = "ATGTCCCAACAACGATACCCCGAGCCCATCGCCGTCATCGGCTCGGCATGCAGATTCCCAGGCG";
+    const PATTERNS: &[&str] = &["GTTGACCTTATCAAC", "GTTGATCCAGTCAAC"];
+
+    fn seq<C: StrictlyPositive>() -> StripedSequence<Dna, C> { // `SEQUENCE`, encoded then striped
+        let encoded = EncodedSequence::<Dna>::encode(SEQUENCE).unwrap(); // panics if encoding fails
+        Pipeline::generic().stripe(encoded)
+    }
+
+    fn pssm() -> ScoringMatrix<Dna> { // scoring matrix derived from `PATTERNS`
+        let cm = CountMatrix::<Dna>::from_sequences(
+            PATTERNS.iter().map(|x| EncodedSequence::encode(x).unwrap()),
+        )
+        .unwrap();
+        let pbm = cm.to_freq(0.1); // NOTE(review): 0.1 is presumably a pseudocount - confirm
+        let pwm = pbm.to_weight(None); // NOTE(review): `None` presumably picks a default background - confirm
+        pwm.to_scoring()
+    }
+
+    #[test]
+    fn test_collect() { // collecting all hits must give exactly positions 18, 27 and 32
+        let pssm = self::pssm();
+        let mut striped = self::seq();
+        striped.configure(&pssm); // NOTE(review): presumably adapts the striped sequence to this matrix - confirm
+
+        let mut scanner = Scanner::new(&pssm, &striped);
+        scanner.threshold(-10.0); // NOTE(review): presumably the score cutoff for reported hits - confirm
+
+        let mut hits = scanner.collect::<Vec<_>>();
+        assert_eq!(hits.len(), 3);
+
+        hits.sort_by_key(|hit| hit.position); // sort so the checks below do not depend on yield order
+        assert_eq!(hits[0].position, 18);
+        assert_eq!(hits[1].position, 27);
+        assert_eq!(hits[2].position, 32);
+    }
+
+    #[test]
+    fn test_max() { // `max()` must return the hit at position 18 scoring about -5.50167
+        let pssm = self::pssm();
+        let mut striped = self::seq();
+        striped.configure(&pssm);
+
+        let mut scanner = Scanner::new(&pssm, &striped);
+        scanner.threshold(-10.0);
+
+        let hit = scanner.max().unwrap(); // panics if no hit is returned
+        assert!(
+            (hit.score - -5.50167).abs() < 1e-5, // float check with an absolute tolerance
+            "{} != {}",
+            hit.score,
+            -5.50167
+        );
+        assert_eq!(hit.position, 18);
+    }
+}
-- 
GitLab