From 81301e877143c1b08933a0187d84f147a6d0e4dd Mon Sep 17 00:00:00 2001
From: Martin Larralde <martin.larralde@embl.de>
Date: Tue, 25 Jun 2024 01:53:50 +0200
Subject: [PATCH] Cache Q-values buffer in `TfmPvalue` structure

---
 lightmotif-tfmpvalue/src/lib.rs | 66 +++++++++++++++++----------------
 1 file changed, 35 insertions(+), 31 deletions(-)

diff --git a/lightmotif-tfmpvalue/src/lib.rs b/lightmotif-tfmpvalue/src/lib.rs
index ad469f7..f6e0aaa 100644
--- a/lightmotif-tfmpvalue/src/lib.rs
+++ b/lightmotif-tfmpvalue/src/lib.rs
@@ -34,6 +34,8 @@ pub struct TfmPvalue<A: Alphabet, M: AsRef<ScoringMatrix<A>>> {
     max_score_rows: Vec<i64>,
     /// The minimum integer score reachable at each row of the matrix.
     min_score_rows: Vec<i64>,
+    /// The Q-values for the current granularity
+    qvalues: Vec<IntMap<f64>>,
 }
 
 #[allow(non_snake_case)]
@@ -49,16 +51,8 @@ impl<A: Alphabet, M: AsRef<ScoringMatrix<A>>> TfmPvalue<A, M> {
         let range = (0..M)
             .map(|i| {
                 let row = &m[i][..A::K::USIZE - 1];
-                let max_score = row
-                    .iter()
-                    .max_by(|x, y| x.partial_cmp(y).unwrap())
-                    .cloned()
-                    .unwrap_or_default();
-                let min_score = row
-                    .iter()
-                    .min_by(|x, y| x.partial_cmp(y).unwrap())
-                    .cloned()
-                    .unwrap_or_default();
+                let max_score = row.iter().cloned().reduce(f32::max).unwrap_or_default();
+                let min_score = row.iter().cloned().reduce(f32::min).unwrap_or_default();
                 max_score - min_score
             })
             .collect::<Vec<_>>();
@@ -73,6 +67,7 @@ impl<A: Alphabet, M: AsRef<ScoringMatrix<A>>> TfmPvalue<A, M> {
             int_matrix: DenseMatrix::new(M),
             max_score_rows: vec![0; M],
             min_score_rows: vec![0; M],
+            qvalues: vec![IntMap::default(); M + 1],
             error_max: 0.0,
         }
     }
@@ -132,7 +127,15 @@ impl<A: Alphabet, M: AsRef<ScoringMatrix<A>>> TfmPvalue<A, M> {
     }
 
     /// Compute the score distribution between `min` and `max`.
-    fn distribution(&self, min: i64, max: i64) -> Vec<IntMap<f64>> {
+    ///
+    /// The resulting distribution is stored in `self.qvalues`.
+    fn distribution(&mut self, min: i64, max: i64) {
+        // Clear Q-values
+        for map in self.qvalues.iter_mut() {
+            map.clear();
+        }
+
+        // matrix reference and dimensions
         let matrix = self.matrix.as_ref();
         let M: usize = matrix.len();
         let K: usize = <A as Alphabet>::K::USIZE;
@@ -140,9 +143,6 @@ impl<A: Alphabet, M: AsRef<ScoringMatrix<A>>> TfmPvalue<A, M> {
         // background frequencies
         let bg = matrix.background().frequencies();
 
-        // maps for each steps of the computation
-        let mut qvalues = vec![IntMap::default(); M + 1];
-
         // maximum score reachable with the suffix matrix from i to M-1
         let mut maxs = vec![0; M + 1];
         for i in (0..M).rev() {
@@ -152,22 +152,24 @@ impl<A: Alphabet, M: AsRef<ScoringMatrix<A>>> TfmPvalue<A, M> {
         // initialize the map at first position with background frequencies
         for k in 0..K - 1 {
             if self.int_matrix[0][k] + maxs[1] >= min {
-                *qvalues[0].entry(self.int_matrix[0][k]).or_default() += bg[k] as f64;
+                *self.qvalues[0].entry(self.int_matrix[0][k]).or_default() += bg[k] as f64;
             }
         }
 
         // compute q values for scores greater or equal to min
-        qvalues[M - 1].insert(max + 1, 0.0);
+        self.qvalues[M - 1].insert(max + 1, 0.0);
         for pos in 1..M {
+            // get the matrix row at the current position
+            let int_row = &self.int_matrix[pos];
             // split the array in two to make the borrow checker happy
-            let (l, r) = qvalues.split_at_mut(pos);
+            let (l, r) = self.qvalues.split_at_mut(pos);
             // iterate on every reachable score at the current position
-            for key in l[pos - 1].keys() {
+            for (key, val) in &l[pos - 1] {
                 for k in 0..K - 1 {
-                    let sc = key + self.int_matrix[pos][k];
+                    let sc = key + int_row[k];
                     if sc + maxs[pos + 1] >= min {
                         // the score min can be reached
-                        let occ = l[pos - 1][key] * bg[k] as f64;
+                        let occ = val * bg[k] as f64;
                         if sc > max {
                             // the score will be greater than max for all suffixes
                             *r[M - 1 - pos].entry(max + 1).or_default() += occ;
@@ -178,12 +180,10 @@ impl<A: Alphabet, M: AsRef<ScoringMatrix<A>>> TfmPvalue<A, M> {
                 }
             }
         }
-
-        qvalues
     }
 
     /// Search the p-value range for the given score.
-    fn lookup_pvalue(&self, score: f64) -> RangeInclusive<f64> {
+    fn lookup_pvalue(&mut self, score: f64) -> RangeInclusive<f64> {
         assert!(!self.granularity.is_nan());
         let matrix = self.matrix.as_ref();
         let M: usize = matrix.len();
@@ -195,16 +195,16 @@ impl<A: Alphabet, M: AsRef<ScoringMatrix<A>>> TfmPvalue<A, M> {
         let min = (scaled - self.error_max - 1.0).floor() as i64;
 
         // Compute q values for the given scores
-        let qvalues = self.distribution(min, max);
+        self.distribution(min, max);
 
         // Compute p-values
         let mut pvalues = IntMap::default();
         let mut s = max + 1;
-        let mut last = qvalues[M - 1].keys().cloned().collect::<Vec<i64>>();
+        let mut last = self.qvalues[M - 1].keys().cloned().collect::<Vec<i64>>();
         last.sort_unstable_by(|x, y| x.partial_cmp(y).unwrap());
-        let mut sum = qvalues[0].get(&(max + 1)).cloned().unwrap_or_default();
+        let mut sum = self.qvalues[0].get(&(max + 1)).cloned().unwrap_or_default();
         for &l in last.iter().rev() {
-            sum += qvalues[M - 1][&l];
+            sum += self.qvalues[M - 1][&l];
             if l >= avg {
                 s = l;
             }
@@ -226,7 +226,11 @@ impl<A: Alphabet, M: AsRef<ScoringMatrix<A>>> TfmPvalue<A, M> {
     }
 
     /// Search the score and p-value range for a given p-value.
-    fn lookup_score(&self, pvalue: f64, range: RangeInclusive<i64>) -> (i64, RangeInclusive<f64>) {
+    fn lookup_score(
+        &mut self,
+        pvalue: f64,
+        range: RangeInclusive<i64>,
+    ) -> (i64, RangeInclusive<f64>) {
         assert!(!self.granularity.is_nan());
         let matrix = self.matrix.as_ref();
         let M: usize = matrix.len();
@@ -236,11 +240,11 @@ impl<A: Alphabet, M: AsRef<ScoringMatrix<A>>> TfmPvalue<A, M> {
         let max = *range.end();
 
         // compute q values
-        let qvalues = self.distribution(min, max);
+        self.distribution(min, max);
         let mut pvalues = IntMap::default();
 
         // find most likely scores at the end of the matrix
-        let mut keys = qvalues[M - 1].keys().cloned().collect::<Vec<_>>();
+        let mut keys = self.qvalues[M - 1].keys().cloned().collect::<Vec<_>>();
         keys.sort_unstable_by(|x, y| x.partial_cmp(y).unwrap());
 
         // compute pvalues
@@ -249,7 +253,7 @@ impl<A: Alphabet, M: AsRef<ScoringMatrix<A>>> TfmPvalue<A, M> {
         let alpha;
         let alpha_e;
         while riter > 0 {
-            sum += qvalues[M - 1][&keys[riter]];
+            sum += self.qvalues[M - 1][&keys[riter]];
             pvalues.insert(keys[riter], sum);
             if sum >= pvalue {
                 break;
-- 
GitLab