From 41a6d1f729a217e05913b0de52fce605f30bda99 Mon Sep 17 00:00:00 2001
From: Martin Larralde <martin.larralde@embl.de>
Date: Fri, 21 Jun 2024 16:52:48 +0200
Subject: [PATCH] Add tests to ensure `Maximum::argmax` works as expected for
 all implementations

---
 lightmotif/src/pli/mod.rs  |   4 +-
 lightmotif/tests/argmax.rs | 142 +++++++++++++++++++++++++++++++++++++
 2 files changed, 145 insertions(+), 1 deletion(-)
 create mode 100644 lightmotif/tests/argmax.rs

diff --git a/lightmotif/src/pli/mod.rs b/lightmotif/src/pli/mod.rs
index 1a5f57f..80a053b 100644
--- a/lightmotif/src/pli/mod.rs
+++ b/lightmotif/src/pli/mod.rs
@@ -3,6 +3,8 @@
 use std::ops::AddAssign;
 use std::ops::Range;
 
+use generic_array::ArrayLength;
+
 use crate::abc::Alphabet;
 use crate::abc::Dna;
 use crate::abc::Protein;
@@ -361,7 +363,7 @@ where
 impl<A, C> Maximum<f32, C> for Pipeline<A, Sse2>
 where
     A: Alphabet,
-    C: StrictlyPositive + MultipleOf<U16>,
+    C: StrictlyPositive + MultipleOf<U16> + ArrayLength,
 {
     #[inline]
     fn argmax(&self, scores: &StripedScores<f32, C>) -> Option<MatrixCoordinates> {
diff --git a/lightmotif/tests/argmax.rs b/lightmotif/tests/argmax.rs
new file mode 100644
index 0000000..fdab04e
--- /dev/null
+++ b/lightmotif/tests/argmax.rs
@@ -0,0 +1,142 @@
+extern crate lightmotif;
+
+use lightmotif::abc::Background;
+use lightmotif::abc::Dna;
+use lightmotif::num::StrictlyPositive;
+use lightmotif::num::U1;
+use lightmotif::num::U16;
+use lightmotif::num::U32;
+use lightmotif::pli::Maximum;
+use lightmotif::pli::Pipeline;
+use lightmotif::pli::Score;
+use lightmotif::pli::Stripe;
+use lightmotif::pwm::CountMatrix;
+use lightmotif::scan::Scanner;
+use lightmotif::scores::StripedScores;
+use lightmotif::seq::EncodedSequence;
+
+const SEQUENCE: &str = include_str!("../benches/ecoli.txt");
+const PATTERNS: &[&str] = &["GTTGACCTTATCAAC", "GTTGATCCAGTCAAC"];
+const N: usize = SEQUENCE.len() / 10;
+
+fn test_argmax_f32<C: StrictlyPositive, P: Maximum<f32, C>>(pli: &P) {
+    let generic = Pipeline::generic();
+
+    let seq = &SEQUENCE[..N];
+    let encoded = EncodedSequence::<Dna>::encode(seq).unwrap();
+    let mut striped = generic.stripe(encoded);
+
+    let bg = Background::<Dna>::uniform();
+    let cm = PATTERNS
+        .iter()
+        .map(EncodedSequence::encode)
+        .map(Result::unwrap)
+        .collect::<Result<CountMatrix<Dna>, _>>()
+        .unwrap();
+    let pbm = cm.to_freq(0.1);
+    let pssm = pbm.to_scoring(bg);
+
+    striped.configure(&pssm);
+    let scores: StripedScores<f32, C> = generic.score(pssm, striped);
+
+    let best = scores
+        .unstripe()
+        .iter()
+        .cloned()
+        .enumerate()
+        .max_by(|x, y| x.1.partial_cmp(&y.1).unwrap())
+        .unwrap();
+    let m = pli.argmax(&scores).unwrap();
+    assert_eq!(scores.offset(m), best.0);
+    assert_eq!(scores.matrix()[m], best.1);
+}
+
+mod generic {
+    use super::*;
+
+    #[test]
+    fn argmax_f32() {
+        let pli = Pipeline::<Dna, _>::generic();
+        super::test_argmax_f32::<U1, _>(&pli);
+        super::test_argmax_f32::<U16, _>(&pli);
+        super::test_argmax_f32::<U32, _>(&pli);
+    }
+}
+
+mod dispatch {
+    use super::*;
+
+    #[test]
+    fn argmax_f32() {
+        let pli = Pipeline::<Dna, _>::dispatch();
+        super::test_argmax_f32(&pli);
+    }
+
+    #[test]
+    fn scanner_max() {
+        let generic = Pipeline::generic();
+
+        let seq = &SEQUENCE[..N];
+        let encoded = EncodedSequence::<Dna>::encode(seq).unwrap();
+        let mut striped = generic.stripe(encoded);
+
+        let bg = Background::<Dna>::uniform();
+        let cm = PATTERNS
+            .iter()
+            .map(EncodedSequence::encode)
+            .map(Result::unwrap)
+            .collect::<Result<CountMatrix<Dna>, _>>()
+            .unwrap();
+        let pbm = cm.to_freq(0.1);
+        let pssm = pbm.to_scoring(bg);
+
+        striped.configure(&pssm);
+        let scores: StripedScores<f32, _> = generic.score(&pssm, &striped);
+        let best = scores
+            .unstripe()
+            .iter()
+            .cloned()
+            .enumerate()
+            .max_by(|x, y| x.1.partial_cmp(&y.1).unwrap())
+            .unwrap();
+
+        let m = Scanner::new(&pssm, &striped).max().unwrap();
+        assert_eq!(m.position, best.0);
+        assert_eq!(m.score, best.1);
+    }
+}
+
+#[cfg(target_feature = "sse2")]
+mod sse2 {
+    use super::*;
+
+    #[test]
+    fn argmax_f32() {
+        let pli = Pipeline::<Dna, _>::sse2().unwrap();
+        super::test_argmax_f32::<U16, _>(&pli);
+        super::test_argmax_f32::<U32, _>(&pli);
+    }
+}
+
+#[cfg(target_feature = "avx2")]
+mod avx2 {
+    use super::*;
+
+    #[test]
+    fn argmax_f32() {
+        let pli = Pipeline::<Dna, _>::avx2().unwrap();
+        super::test_argmax_f32(&pli);
+    }
+}
+
+#[cfg(target_feature = "neon")]
+mod neon {
+    use super::*;
+
+    #[test]
+    fn argmax_f32() {
+        let pli = Pipeline::<Dna, _>::neon().unwrap();
+        super::test_argmax_f32::<U16, _>(&pli);
+        super::test_argmax_f32::<U32, _>(&pli);
+    }
+}
-- 
GitLab