From 5c25864829897ec2520800dfdea1cac13b8c0f78 Mon Sep 17 00:00:00 2001
From: Martin Larralde <martin.larralde@embl.de>
Date: Mon, 24 Jun 2024 13:44:46 +0200
Subject: [PATCH] First check if block maximum is high enough  before running
 `Pipeline::threshold` in `Scanner`

---
 lightmotif/src/scan.rs | 70 +++++++++++++++++++++++++++++-------------
 1 file changed, 49 insertions(+), 21 deletions(-)

diff --git a/lightmotif/src/scan.rs b/lightmotif/src/scan.rs
index 726ed80..86054b3 100644
--- a/lightmotif/src/scan.rs
+++ b/lightmotif/src/scan.rs
@@ -91,6 +91,7 @@ impl Ord for Hit {
     }
 }
 
+/// A scanner for iterating over scoring matrix hits in a sequence.
 #[derive(Debug)]
 pub struct Scanner<'a, A: Alphabet, C: StrictlyPositive = DefaultColumns> {
     pssm: &'a ScoringMatrix<A>,
@@ -157,45 +158,72 @@ where
             // score the row slice
             self.pipeline
                 .score_rows_into(&self.dm, &self.seq, self.row..end, &mut self.dscores);
-            // scan through the positions above discrete threshold and recompute
-            // scores in floating-point to see if they pass the real threshold.
-            for c in self.pipeline.threshold(&self.dscores, t) {
-                let index = c.col * (self.seq.matrix().rows() - self.seq.wrap()) + self.row + c.row;
-                let score = self.pssm.score_position(&self.seq, index);
-                if score >= self.threshold {
-                    self.hits.push(Hit::new(index, score));
+            // check if any position is higher than the discrete threshold.
+            if self.pipeline.max(&self.dscores).unwrap() >= t {
+                // scan through the positions above discrete threshold and recompute
+                // scores in floating-point to see if they pass the real threshold.
+                for c in self.pipeline.threshold(&self.dscores, t) {
+                    let index =
+                        c.col * (self.seq.matrix().rows() - self.seq.wrap()) + self.row + c.row;
+                    let score = self.pssm.score_position(&self.seq, index);
+                    if score >= self.threshold {
+                        self.hits.push(Hit::new(index, score));
+                    }
                 }
             }
+            // Proceed to the next block.
             self.row += self.block_size;
         }
         self.hits.pop()
     }
 
     fn max(mut self) -> Option<Self::Item> {
+        // Compute the score of the best hit not yet returned, and translate
+        // the `f32` score threshold into a discrete, under-estimate `u8`
+        // threshold.
         let mut best = std::mem::take(&mut self.hits)
             .into_iter()
-            .max_by(|x, y| x.score.partial_cmp(&y.score).unwrap())
-            .unwrap_or(Hit::new(0, f32::MIN));
-        let mut best_discrete = self.dm.scale(best.score.max(self.threshold));
+            .filter(|hit| hit.score >= self.threshold)
+            .max_by(|x, y| x.score.partial_cmp(&y.score).unwrap());
+        let mut best_discrete = match &best {
+            Some(hit) => self.dm.scale(hit.score),
+            None => self.dm.scale(self.threshold),
+        };
+
+        // Cache th number of sequence rows in the striped sequence matrix.
+        let sequence_rows = self.seq.matrix().rows() - self.seq.wrap();
+
+        // Process all rows of the sequence and record the local
         while self.row < self.seq.matrix().rows() {
-            let end = (self.row + self.block_size).min(self.seq.matrix().rows() - self.seq.wrap());
+            // Score the rows of the current block.
+            let end = (self.row + self.block_size).min(sequence_rows);
             self.pipeline
                 .score_rows_into(&self.dm, &self.seq, self.row..end, &mut self.dscores);
-            if let Some(c) = self.pipeline.argmax(&self.dscores) {
-                let dscore = self.dscores.matrix()[c];
-                if dscore >= best_discrete {
-                    let index =
-                        c.col * (self.seq.matrix().rows() - self.seq.wrap()) + self.row + c.row;
-                    let score = self.pssm.score_position(&self.seq, index);
-                    if (score > best.score) | (score == best.score && index > best.position) {
-                        best = Hit::new(index, score);
-                        best_discrete = dscore;
+            // Check if the highest score in the block is high enough to be
+            // a new global maximum
+            if self.pipeline.max(&self.dscores).unwrap() >= best_discrete {
+                // Iterate over candidate position in `u8` scores and recalculate
+                // scores for candidates passing the threshold.
+                for c in self.pipeline.threshold(&self.dscores, best_discrete) {
+                    let dscore = self.dscores.matrix()[c];
+                    if dscore >= best_discrete {
+                        let index = c.col * sequence_rows + self.row + c.row;
+                        let score = self.pssm.score_position(&self.seq, index);
+                        if let Some(hit) = &best {
+                            if (score > hit.score) | (score == hit.score && index > hit.position) {
+                                best = Some(Hit::new(index, score));
+                                best_discrete = dscore;
+                            }
+                        } else {
+                            best = Some(Hit::new(index, score))
+                        }
                     }
                 }
             }
+            // Proceed to the next block.
             self.row += self.block_size;
         }
-        Some(best)
+        best
     }
 }
 
-- 
GitLab