From c93271a6d91ad416e6e68fd6fcb19aba27892019 Mon Sep 17 00:00:00 2001
From: Martin Larralde <martin.larralde@embl.de>
Date: Thu, 20 Jun 2024 11:32:56 +0200
Subject: [PATCH] Add `DiscreteMatrix` type to `lightmotif::pwm`

---
 lightmotif/src/pwm.rs | 150 +++++++++++++++++++++++++++++++++++-------
 1 file changed, 125 insertions(+), 25 deletions(-)

diff --git a/lightmotif/src/pwm.rs b/lightmotif/src/pwm.rs
index b1f6fc1..c4e3d5e 100644
--- a/lightmotif/src/pwm.rs
+++ b/lightmotif/src/pwm.rs
@@ -24,9 +24,15 @@ macro_rules! matrix_traits {
         impl<A: Alphabet> $mx<A> {
             /// The raw data storage for the matrix.
             #[inline]
-            pub fn matrix(&self) -> &DenseMatrix<$t, A::K> {
+            pub const fn matrix(&self) -> &DenseMatrix<$t, A::K> {
                 &self.data
             }
+
+            /// The length of the motif encoded in this matrix, i.e. its row count.
+            #[inline]
+            pub const fn len(&self) -> usize {
+                self.data.rows()
+            }
         }
 
         impl<A: Alphabet> AsRef<$mx<A>> for $mx<A> {
@@ -93,12 +99,6 @@ impl<A: Alphabet> CountMatrix<A> {
         // }
     }
 
-    /// The length of the motif encoded in this count matrix.
-    #[inline]
-    pub const fn len(&self) -> usize {
-        self.data.rows()
-    }
-
     /// Create a new count matrix from the given sequences.
     ///
     /// # Errors
@@ -239,12 +239,6 @@ impl<A: Alphabet> FrequencyMatrix<A> {
         }
     }
 
-    /// The length of the motif encoded in this frequency matrix.
-    #[inline]
-    pub const fn len(&self) -> usize {
-        self.data.rows()
-    }
-
     /// Create a new frequency matrix.
     ///
     /// The matrix must contain frequency data, i.e. rows should all sum to 1
@@ -341,12 +335,6 @@ impl<A: Alphabet> WeightMatrix<A> {
         Self { background, data }
     }
 
-    /// The length of the motif encoded in this weight matrix.
-    #[inline]
-    pub const fn len(&self) -> usize {
-        self.data.rows()
-    }
-
     /// The background frequencies of the position weight matrix.
     #[inline]
     pub fn background(&self) -> &Background<A> {
@@ -455,12 +443,6 @@ impl<A: Alphabet> ScoringMatrix<A> {
         Self { background, data }
     }
 
-    /// The length of the motif encoded in this scoring matrix.
-    #[inline]
-    pub const fn len(&self) -> usize {
-        self.data.rows()
-    }
-
     /// The background frequencies of the position weight matrix.
     #[inline]
     pub fn background(&self) -> &Background<A> {
@@ -520,6 +502,37 @@ impl<A: Alphabet> ScoringMatrix<A> {
         }
         score
     }
+
+    /// Discretize this position-specific scoring matrix into `u8` scores.
+    ///
+    /// Cells are shifted by a per-row offset, rescaled, and rounded up.
+    pub fn to_discrete(&self) -> DiscreteMatrix<A> {
+        let pssm = self.matrix();
+        // Per-row offset: the lowest score, ignoring the last column.
+        let mut offsets = Vec::with_capacity(self.len());
+        for row in pssm.iter() {
+            let lowest = row[..A::K::USIZE - 1]
+                .iter()
+                .min_by(|a, b| a.partial_cmp(b).unwrap())
+                .unwrap();
+            offsets.push(*lowest);
+        }
+        let offset: f32 = offsets.iter().sum();
+        let factor = (self.max_score() - offset) / f32::from(u8::MAX);
+        let mut data = DenseMatrix::new(self.len());
+        for (i, &row_offset) in offsets.iter().enumerate() {
+            for j in 0..data.columns() {
+                // Rounding up keeps discrete scores from under-estimating.
+                data[i][j] = ((pssm[i][j] - row_offset) / factor).ceil() as u8;
+            }
+        }
+        DiscreteMatrix {
+            data,
+            factor,
+            offsets,
+            offset,
+        }
+    }
 }
 
 impl<A: Alphabet> From<WeightMatrix<A>> for ScoringMatrix<A> {
@@ -529,3 +542,90 @@ impl<A: Alphabet> From<WeightMatrix<A>> for ScoringMatrix<A> {
 }
 
 matrix_traits!(ScoringMatrix, f32);
+
+// --- DiscreteMatrix ----------------------------------------------------------
+
+/// A position-specific scoring matrix discretized over `u8::MIN..=u8::MAX`.
+///
+/// # Note
+/// The discretization rounds every score *up*, so that the scores computed
+/// through a discrete matrix are an *over*-estimation of the actual scores.
+/// This allows for the fast scanning for candidate positions, which then
+/// have to be scanned again with the full PSSM to compute the real
+/// scores.
+///
+/// # Example
+/// ```
+/// # use lightmotif::*;
+/// # let counts = CountMatrix::<Dna>::from_sequences(
+/// #  ["GTTGACCTTATCAAC", "GTTGATCCAGTCAAC"]
+/// #        .into_iter()
+/// #        .map(|s| EncodedSequence::encode(s).unwrap()),
+/// # )
+/// # .unwrap();
+/// # let pssm = counts.to_freq(0.1).to_scoring(None);
+/// # let seq = "ATGTCCCAACAACGATACCCCGAGCCCATCGCCGTCATCGGCTCGGCATGCAGATTCCCAGGCG";
+/// # let mut striped = EncodedSequence::encode(seq).unwrap().to_striped();
+/// // Create a `DiscreteMatrix` from a `ScoringMatrix`
+/// let discrete = pssm.to_discrete();
+///
+/// // The discrete scores are never lower than the real scores.
+/// for j in 0..seq.len() - pssm.len() + 1 {
+///     let score_f32 = pssm.score_position(&striped, j);
+///     let score_u8 = discrete.unscale(discrete.score_position(&striped, j));
+///     assert!(score_u8 >= score_f32);
+/// }
+/// ```
+#[derive(Clone, Debug, PartialEq)]
+pub struct DiscreteMatrix<A: Alphabet> {
+    data: DenseMatrix<u8, A::K>,
+    factor: f32,
+    offsets: Vec<f32>,
+    offset: f32,
+}
+
+impl<A: Alphabet> DiscreteMatrix<A> {
+    /// Compute the score for a single sequence position, saturating at `u8::MAX`.
+    pub fn score_position<S, C>(&self, seq: S, pos: usize) -> u8
+    where
+        C: StrictlyPositive,
+        S: AsRef<StripedSequence<A, C>>,
+    {
+        let s = seq.as_ref();
+        // Cells are rounded up, so a sum may exceed `u8::MAX`: saturate rather
+        // than overflow, which keeps the result an over-estimation.
+        self.data.iter().enumerate().fold(0, |score: u8, (j, row)| {
+            score.saturating_add(row[s[pos + j].as_index()])
+        })
+    }
+
+    /// Scale the given score to an integer score using the matrix scale.
+    ///
+    /// # Note
+    /// This function rounds down the final score, and is suitable to translate
+    /// an `f32` score threshold to a `u8` score threshold.
+    #[inline]
+    pub fn scale(&self, score: f32) -> u8 {
+        ((score - self.offset) / self.factor).floor() as u8
+    }
+
+    /// Unscale the given integer score into a score using the matrix scale.
+    #[inline]
+    pub fn unscale(&self, score: u8) -> f32 {
+        (score as f32) * self.factor + self.offset
+    }
+}
+
+impl<A: Alphabet> From<ScoringMatrix<A>> for DiscreteMatrix<A> {
+    fn from(pssm: ScoringMatrix<A>) -> Self {
+        pssm.to_discrete() // consumes `pssm`; the discretization only borrows it
+    }
+}
+
+impl<A: Alphabet> From<&ScoringMatrix<A>> for DiscreteMatrix<A> {
+    fn from(pssm: &ScoringMatrix<A>) -> Self {
+        ScoringMatrix::to_discrete(pssm) // borrowing conversion
+    }
+}
+
+matrix_traits!(DiscreteMatrix, u8);
-- 
GitLab