From ce993e6d8471227d1a24ec1435f3b23736ba42d4 Mon Sep 17 00:00:00 2001
From: Martin Larralde <martin.larralde@embl.de>
Date: Fri, 14 Jun 2024 23:41:47 +0200
Subject: [PATCH] Fix issue with scoring of an empty range of rows

---
 lightmotif/src/dense.rs             |  3 +++
 lightmotif/src/pli/mod.rs           | 39 +++++++++++++++++++++++++++--
 lightmotif/src/pli/platform/avx2.rs | 12 +++++----
 lightmotif/src/pli/platform/neon.rs |  4 +--
 lightmotif/src/pli/platform/sse2.rs |  4 +--
 lightmotif/src/scan.rs              |  5 +++-
 lightmotif/src/seq.rs               |  2 ++
 7 files changed, 57 insertions(+), 12 deletions(-)

diff --git a/lightmotif/src/dense.rs b/lightmotif/src/dense.rs
index c3ab0bf..4fe699c 100644
--- a/lightmotif/src/dense.rs
+++ b/lightmotif/src/dense.rs
@@ -216,6 +216,7 @@ impl<T: Default + Copy, C: Unsigned, A: Unsigned + PowerOfTwo> Index<usize>
     fn index(&self, index: usize) -> &Self::Output {
         let c = self.stride();
         let row = self.offset + c * index;
+        debug_assert!(row + C::USIZE <= self.data.len());
         &self.data[row..row + C::USIZE]
     }
 }
@@ -227,6 +228,7 @@ impl<T: Default + Copy, C: Unsigned, A: Unsigned + PowerOfTwo> IndexMut<usize>
     fn index_mut(&mut self, index: usize) -> &mut Self::Output {
         let c = self.stride();
         let row = self.offset + c * index;
+        debug_assert!(row + C::USIZE <= self.data.len());
         &mut self.data[row..row + C::USIZE]
     }
 }
@@ -239,6 +241,7 @@ impl<T: Default + Copy, C: Unsigned, A: Unsigned + PowerOfTwo> Index<MatrixCoord
     fn index(&self, index: MatrixCoordinates) -> &Self::Output {
         let c = self.stride();
         let i = self.offset + c * index.row + index.col;
+        debug_assert!(i < self.data.len());
         &self.data[i]
     }
 }
diff --git a/lightmotif/src/pli/mod.rs b/lightmotif/src/pli/mod.rs
index 40685b2..297c60a 100644
--- a/lightmotif/src/pli/mod.rs
+++ b/lightmotif/src/pli/mod.rs
@@ -81,13 +81,13 @@ pub trait Score<A: Alphabet, C: StrictlyPositive> {
         let seq = seq.as_ref();
         let pssm = pssm.as_ref();
 
-        if seq.len() < pssm.len() {
+        if seq.len() < pssm.len() || rows.len() == 0 {
             scores.resize(0, 0);
             return;
         }
 
         // FIXME?
-        scores.resize(rows.len(), seq.len() - pssm.len() + 1);
+        scores.resize(rows.len(), (seq.len() + 1).saturating_sub(pssm.len()));
 
         let result = scores.matrix_mut();
         let matrix = pssm.matrix();
@@ -471,3 +471,38 @@ where
     C: StrictlyPositive + MultipleOf<U16>,
 {
 }
+
+// -- Tests --------------------------------------------------------------------
+
+#[cfg(test)]
+mod test {
+    use std::str::FromStr;
+    use typenum::consts::U4;
+
+    use super::*;
+
+    use crate::abc::Dna;
+    use crate::pwm::CountMatrix;
+
+    #[test]
+    fn test_score_rows_into_empty() {
+        let pli = Pipeline::generic();
+
+        let seq = EncodedSequence::<Dna>::from_str("ATGCA").unwrap();
+        let mut striped = <Pipeline<_, _> as Stripe<Dna, U4>>::stripe(&pli, seq);
+
+        let cm = CountMatrix::<Dna>::from_sequences(
+            ["ATTA", "ATTC"]
+                .iter()
+                .map(|x| EncodedSequence::encode(x).unwrap()),
+        )
+        .unwrap();
+        let pbm = cm.to_freq(0.1);
+        let pwm = pbm.to_weight(None);
+        let pssm = pwm.to_scoring();
+
+        striped.configure(&pssm);
+        let mut scores = StripedScores::empty();
+        pli.score_rows_into(pssm, striped, 1..1, &mut scores);
+    }
+}
diff --git a/lightmotif/src/pli/platform/avx2.rs b/lightmotif/src/pli/platform/avx2.rs
index b8f262c..367c502 100644
--- a/lightmotif/src/pli/platform/avx2.rs
+++ b/lightmotif/src/pli/platform/avx2.rs
@@ -105,6 +105,8 @@ unsafe fn score_avx2_permute<A>(
     <<A as Alphabet>::K as IsLessOrEqual<U5>>::Output: NonZero,
 {
     let data = scores.matrix_mut();
+    debug_assert!(data.rows() > 0);
+
     let mut rowptr = data[0].as_mut_ptr();
     // constant vector for comparing unknown bases
     let n = _mm256_set1_epi32(<A as Alphabet>::K::I32 - 1);
@@ -406,7 +408,7 @@ unsafe fn stripe_avx2<A>(
     let mut matrix = std::mem::take(striped).into_matrix();
     matrix.resize(src_stride);
 
-    // Early exit if sequence is too empty (no allocated matrix).
+    // Early exit if sequence is empty (no allocated matrix).
     if length == 0 {
         return;
     }
@@ -637,12 +639,12 @@ impl Avx2 {
             );
         }
 
-        if seq.len() < pssm.len() {
+        if seq.len() < pssm.len() || rows.len() == 0 {
             scores.resize(0, 0);
             return;
         }
 
-        scores.resize(rows.len(), seq.len() - pssm.len() + 1);
+        scores.resize(rows.len(), (seq.len() + 1).saturating_sub(pssm.len()));
         #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
         unsafe {
             score_avx2_permute(pssm, seq, rows, scores)
@@ -672,12 +674,12 @@ impl Avx2 {
             );
         }
 
-        if seq.len() < pssm.len() {
+        if seq.len() < pssm.len() || rows.len() == 0 {
             scores.resize(0, 0);
             return;
         }
 
-        scores.resize(rows.len(), seq.len() - pssm.len() + 1);
+        scores.resize(rows.len(), (seq.len() + 1).saturating_sub(pssm.len()));
         #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
         unsafe {
             score_avx2_gather(pssm, seq, rows, scores)
diff --git a/lightmotif/src/pli/platform/neon.rs b/lightmotif/src/pli/platform/neon.rs
index c7751a1..bc1eb3a 100644
--- a/lightmotif/src/pli/platform/neon.rs
+++ b/lightmotif/src/pli/platform/neon.rs
@@ -215,12 +215,12 @@ impl Neon {
             );
         }
 
-        if seq.len() < pssm.len() {
+        if seq.len() < pssm.len() || rows.len() == 0 {
             scores.resize(0, 0);
             return;
         }
 
-        scores.resize(rows.len(), seq.len() - pssm.len() + 1);
+        scores.resize(rows.len(), (seq.len() + 1).saturating_sub(pssm.len()));
         #[cfg(any(target_arch = "arm", target_arch = "aarch64"))]
         unsafe {
             score_neon(pssm, seq, rows, scores);
diff --git a/lightmotif/src/pli/platform/sse2.rs b/lightmotif/src/pli/platform/sse2.rs
index aca5c3e..92339d5 100644
--- a/lightmotif/src/pli/platform/sse2.rs
+++ b/lightmotif/src/pli/platform/sse2.rs
@@ -205,12 +205,12 @@ impl Sse2 {
             );
         }
 
-        if seq.len() < pssm.len() {
+        if seq.len() < pssm.len() || rows.len() == 0 {
             scores.resize(0, 0);
             return;
         }
 
-        scores.resize(rows.len(), seq.len() - pssm.len() + 1);
+        scores.resize(rows.len(), (seq.len() + 1).saturating_sub(pssm.len()));
         #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
         unsafe {
             score_sse2(pssm, seq, rows, scores);
diff --git a/lightmotif/src/scan.rs b/lightmotif/src/scan.rs
index adb2fe3..ce4444e 100644
--- a/lightmotif/src/scan.rs
+++ b/lightmotif/src/scan.rs
@@ -4,6 +4,8 @@ use super::pli::dispatch::Dispatch;
 use super::pli::platform::Backend;
 use super::pwm::ScoringMatrix;
 use super::seq::StripedSequence;
+use crate::dense::DenseMatrix;
+use crate::num::Unsigned;
 use crate::pli::Maximum;
 use crate::pli::Pipeline;
 use crate::pli::Score;
@@ -141,7 +143,8 @@ where
     fn next(&mut self) -> Option<Self::Item> {
         while self.hits.is_empty() && self.row < self.seq.matrix().rows() {
             let pli = Pipeline::dispatch();
-            let end = (self.row + self.block_size).min(self.seq.matrix().rows() - self.seq.wrap());
+            let end = (self.row + self.block_size)
+                .min(self.seq.matrix().rows().saturating_sub(self.seq.wrap()));
             pli.score_rows_into(&self.pssm, &self.seq, self.row..end, &mut self.scores);
             let matrix = self.scores.matrix();
             for c in pli.threshold(&self.scores, self.threshold) {
diff --git a/lightmotif/src/seq.rs b/lightmotif/src/seq.rs
index 023a1d1..1880de4 100644
--- a/lightmotif/src/seq.rs
+++ b/lightmotif/src/seq.rs
@@ -337,6 +337,8 @@ impl<A: Alphabet, C: StrictlyPositive> SymbolCount<A> for &StripedSequence<A, C>
     }
 }
 
+// -- Tests --------------------------------------------------------------------
+
 #[cfg(test)]
 mod test {
     use typenum::consts::U2;
-- 
GitLab