From bbde5508117eb0ee39026993e4c372ed835ba470 Mon Sep 17 00:00:00 2001
From: Martin Larralde <martin.larralde@embl.de>
Date: Thu, 20 Jun 2024 17:14:35 +0200
Subject: [PATCH] Reorganize tests in `lightmotif` crate

---
 lightmotif/tests/dna.rs    | 206 +++++++++++++++++++++++--------------
 lightmotif/tests/encode.rs |  95 +++++++++++++----
 2 files changed, 200 insertions(+), 101 deletions(-)

diff --git a/lightmotif/tests/dna.rs b/lightmotif/tests/dna.rs
index a6aaabf..989dec1 100644
--- a/lightmotif/tests/dna.rs
+++ b/lightmotif/tests/dna.rs
@@ -89,6 +89,35 @@ fn test_score<C: StrictlyPositive, P: Score<f32, Dna, C>>(pli: &P) {
     }
 }
 
+fn test_score_discrete<C: StrictlyPositive, P: Score<u8, Dna, C>>(pli: &P) {
+    let encoded = EncodedSequence::<Dna>::encode(SEQUENCE).unwrap();
+    let mut striped = Pipeline::generic().stripe(encoded);
+
+    let cm = CountMatrix::<Dna>::from_sequences(
+        PATTERNS.iter().map(|x| EncodedSequence::encode(x).unwrap()),
+    )
+    .unwrap();
+    let pbm = cm.to_freq(0.1);
+    let pwm = pbm.to_weight(None);
+    let pssm = pwm.to_scoring();
+    let dm = pssm.to_discrete();
+
+    striped.configure(&pssm);
+    let result = pli.score(&dm, &striped);
+    let scores = result.unstripe();
+
+    assert_eq!(scores.len(), EXPECTED.len());
+    for i in 0..scores.len() {
+        assert!(
+            dm.unscale(scores[i]) >= EXPECTED[i],
+            "{} != {} at position {}",
+            dm.unscale(scores[i]),
+            EXPECTED[i],
+            i
+        );
+    }
+}
+
 fn test_argmax<C: StrictlyPositive, P: Score<f32, Dna, C> + Maximum<f32, C>>(pli: &P) {
     let encoded = EncodedSequence::<Dna>::encode(SEQUENCE).unwrap();
     let mut striped = Pipeline::generic().stripe(encoded);
@@ -197,104 +226,121 @@ mod generic {
     }
 }
 
-#[test]
-fn test_score_dispatch() {
-    let pli = Pipeline::dispatch();
-    test_score(&pli);
-}
+mod dispatch {
+    use super::*;
 
-#[test]
-fn test_argmax_dispatch() {
-    let pli = Pipeline::dispatch();
-    test_argmax(&pli);
-}
+    #[test]
+    fn test_score() {
+        let pli = Pipeline::dispatch();
+        super::test_score::<U32, _>(&pli);
+    }
 
-#[test]
-fn test_threshold_dispatch() {
-    let pli = Pipeline::dispatch();
-    test_threshold(&pli);
-}
+    #[test]
+    fn test_score_rows() {
+        let pli = Pipeline::dispatch();
+        super::test_score_rows::<U32, _>(&pli);
+    }
 
-#[cfg(target_feature = "sse2")]
-#[test]
-fn test_score_sse2() {
-    let pli = Pipeline::sse2().unwrap();
-    test_score::<U16, _>(&pli);
-}
+    #[test]
+    fn test_argmax() {
+        let pli = Pipeline::dispatch();
+        super::test_argmax::<U32, _>(&pli);
+    }
 
-#[cfg(target_feature = "sse2")]
-#[test]
-fn test_argmax_sse2() {
-    let pli = Pipeline::sse2().unwrap();
-    test_argmax::<U16, _>(&pli);
+    #[test]
+    fn test_threshold() {
+        let pli = Pipeline::dispatch();
+        super::test_threshold::<U32, _>(&pli);
+    }
 }
 
 #[cfg(target_feature = "sse2")]
-#[test]
-fn test_threshold_sse2() {
-    let pli = Pipeline::sse2().unwrap();
-    test_threshold::<U16, _>(&pli);
-}
+mod sse2 {
+    use super::*;
 
-#[cfg(target_feature = "sse2")]
-#[test]
-fn test_score_sse2_32() {
-    let pli = Pipeline::sse2().unwrap();
-    test_score::<U32, _>(&pli);
-}
+    #[test]
+    fn test_score() {
+        let pli = Pipeline::sse2().unwrap();
+        super::test_score::<U32, _>(&pli);
+        super::test_score::<U16, _>(&pli);
+    }
 
-#[cfg(target_feature = "sse2")]
-#[test]
-fn test_argmax_sse2_32() {
-    let pli = Pipeline::sse2().unwrap();
-    test_argmax::<U32, _>(&pli);
-}
+    #[test]
+    fn test_score_rows() {
+        let pli = Pipeline::sse2().unwrap();
+        super::test_score_rows::<U32, _>(&pli);
+        super::test_score_rows::<U16, _>(&pli);
+    }
 
-#[cfg(target_feature = "sse2")]
-#[test]
-fn test_threshold_sse2_32() {
-    let pli = Pipeline::sse2().unwrap();
-    test_threshold::<U32, _>(&pli);
-}
+    #[test]
+    fn test_argmax() {
+        let pli = Pipeline::sse2().unwrap();
+        super::test_argmax::<U32, _>(&pli);
+        super::test_argmax::<U16, _>(&pli);
+    }
 
-#[cfg(target_feature = "avx2")]
-#[test]
-fn test_score_avx2() {
-    let pli = Pipeline::avx2().unwrap();
-    test_score(&pli);
+    #[test]
+    fn test_threshold() {
+        let pli = Pipeline::sse2().unwrap();
+        super::test_threshold::<U32, _>(&pli);
+        super::test_threshold::<U16, _>(&pli);
+    }
 }
 
 #[cfg(target_feature = "avx2")]
-#[test]
-fn test_argmax_avx2() {
-    let pli = Pipeline::avx2().unwrap();
-    test_argmax(&pli);
-}
+mod avx2 {
+    use super::*;
 
-#[cfg(target_feature = "avx2")]
-#[test]
-fn test_threshold_avx2() {
-    let pli = Pipeline::avx2().unwrap();
-    test_threshold::<U32, _>(&pli);
-}
+    #[test]
+    fn test_score() {
+        let pli = Pipeline::avx2().unwrap();
+        super::test_score::<U32, _>(&pli);
+    }
 
-#[cfg(target_feature = "neon")]
-#[test]
-fn test_score_neon() {
-    let pli = Pipeline::neon().unwrap();
-    test_score::<U16, _>(&pli);
-}
+    #[test]
+    fn test_score_rows() {
+        let pli = Pipeline::avx2().unwrap();
+        super::test_score_rows::<U32, _>(&pli);
+    }
 
-#[cfg(target_feature = "neon")]
-#[test]
-fn test_argmax_neon() {
-    let pli = Pipeline::neon().unwrap();
-    test_argmax::<U16, _>(&pli);
+    #[test]
+    fn test_argmax() {
+        let pli = Pipeline::avx2().unwrap();
+        super::test_argmax::<U32, _>(&pli);
+    }
+
+    #[test]
+    fn test_threshold() {
+        let pli = Pipeline::avx2().unwrap();
+        super::test_threshold::<U32, _>(&pli);
+    }
 }
 
 #[cfg(target_feature = "neon")]
-#[test]
-fn test_threshold_neon() {
-    let pli = Pipeline::neon().unwrap();
-    test_threshold::<U16, _>(&pli);
+mod neon {
+    use super::*;
+
+    #[test]
+    fn test_score() {
+        let pli = Pipeline::neon().unwrap();
+        super::test_score::<U16, _>(&pli);
+    }
+
+    #[test]
+    fn test_score_rows() {
+        let pli = Pipeline::neon().unwrap();
+        super::test_score_rows::<U16, _>(&pli);
+    }
+
+    #[test]
+    fn test_argmax() {
+        let pli = Pipeline::neon().unwrap();
+        super::test_argmax::<U16, _>(&pli);
+    }
+
+    #[test]
+    fn test_threshold() {
+        let pli = Pipeline::neon().unwrap();
+        super::test_threshold::<U16, _>(&pli);
+    }
 }
diff --git a/lightmotif/tests/encode.rs b/lightmotif/tests/encode.rs
index 0356bcc..8376df4 100644
--- a/lightmotif/tests/encode.rs
+++ b/lightmotif/tests/encode.rs
@@ -15,42 +15,95 @@ const EXPECTED: &[Nucleotide] = &[
     A, T, G, T, C, C, C, A, A, C, A, A, C, G, A, T, A, C, C, C, C, G, A, G, C, C, C, A, T, C, G, C, C, G, T, C, A, T, C, G, G, C, T, C, G, G, C, A, T, G, C, A, G, A, T, T, C, C, C, A, G, G, C, G 
 ];
 
-fn test_encode<P: Encode<Dna>>(pli: &P) {
+fn test_encode_sequence<P: Encode<Dna>>(pli: &P) {
     let encoded = pli.encode(SEQUENCE).unwrap();
     assert_eq!(encoded, EXPECTED);
+}
+
+fn test_encode_unknown<P: Encode<Dna>>(pli: &P) {
     let err = pli.encode(UNKNOWNS).unwrap_err();
     assert_eq!(err.0, '.');
 }
 
-#[test]
-fn test_encode_generic() {
-    let pli = Pipeline::generic();
-    test_encode(&pli);
+mod generic {
+    use super::*;
+
+    #[test]
+    fn test_sequence() {
+        let pli = Pipeline::generic();
+        test_encode_sequence(&pli);
+    }
+
+    #[test]
+    fn test_unknown() {
+        let pli = Pipeline::generic();
+        test_encode_unknown(&pli);
+    }
 }
 
-#[test]
-fn test_encode_dispatch() {
-    let pli = Pipeline::dispatch();
-    test_encode(&pli);
+mod dispatch {
+    use super::*;
+
+    #[test]
+    fn test_sequence() {
+        let pli = Pipeline::dispatch();
+        test_encode_sequence(&pli);
+    }
+
+    #[test]
+    fn test_unknown() {
+        let pli = Pipeline::dispatch();
+        test_encode_unknown(&pli);
+    }
 }
 
 #[cfg(target_feature = "sse2")]
-#[test]
-fn test_encode_sse2() {
-    let pli = Pipeline::sse2().unwrap();
-    test_encode(&pli);
+mod sse2 {
+    use super::*;
+
+    #[test]
+    fn test_sequence() {
+        let pli = Pipeline::sse2().unwrap();
+        test_encode_sequence(&pli);
+    }
+
+    #[test]
+    fn test_unknown() {
+        let pli = Pipeline::sse2().unwrap();
+        test_encode_unknown(&pli);
+    }
 }
 
 #[cfg(target_feature = "avx2")]
-#[test]
-fn test_encode_avx2() {
-    let pli = Pipeline::avx2().unwrap();
-    test_encode(&pli);
+mod avx2 {
+    use super::*;
+
+    #[test]
+    fn test_sequence() {
+        let pli = Pipeline::avx2().unwrap();
+        test_encode_sequence(&pli);
+    }
+
+    #[test]
+    fn test_unknown() {
+        let pli = Pipeline::avx2().unwrap();
+        test_encode_unknown(&pli);
+    }
 }
 
 #[cfg(target_feature = "neon")]
-#[test]
-fn test_encode_neon() {
-    let pli = Pipeline::neon().unwrap();
-    test_encode(&pli);
+mod neon {
+    use super::*;
+
+    #[test]
+    fn test_sequence() {
+        let pli = Pipeline::neon().unwrap();
+        test_encode_sequence(&pli);
+    }
+
+    #[test]
+    fn test_unknown() {
+        let pli = Pipeline::neon().unwrap();
+        test_encode_unknown(&pli);
+    }
 }
-- 
GitLab