From 84200f258429e2f6988bec8f0db91a622a1837f7 Mon Sep 17 00:00:00 2001
From: Martin Larralde <martin.larralde@embl.de>
Date: Sat, 29 Apr 2023 12:46:51 +0200
Subject: [PATCH] Add remaining types for creating a weight matrix

---
 src/pwm.rs | 74 ++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 74 insertions(+)

diff --git a/src/pwm.rs b/src/pwm.rs
index 3ade48b..4717545 100644
--- a/src/pwm.rs
+++ b/src/pwm.rs
@@ -2,6 +2,7 @@ use super::abc::Alphabet;
 use super::abc::Symbol;
 use super::dense::DenseMatrix;
 use super::seq::EncodedSequence;
+use super::seq::StripedSequence;
 
 #[derive(Clone, Debug)]
 pub struct CountMatrix<A: Alphabet, const K: usize> {
@@ -10,6 +11,13 @@ pub struct CountMatrix<A: Alphabet, const K: usize> {
 }
 
 impl<A: Alphabet, const K: usize> CountMatrix<A, K> {
+    /// Wrap `data` in a count matrix paired with the default alphabet value.
+    pub fn new(data: DenseMatrix<u32, K>) -> Result<Self, ()> {
+        // Currently infallible: every input is accepted and wrapped in `Ok`.
+        let alphabet = A::default();
+        Ok(Self { data, alphabet })
+    }
+
     pub fn from_sequences<'seq, I>(sequences: I) -> Result<Self, ()>
     where
         I: IntoIterator<Item = &'seq EncodedSequence<A>>,
@@ -33,4 +41,70 @@ impl<A: Alphabet, const K: usize> CountMatrix<A, K> {
             data: data.unwrap_or_else(|| DenseMatrix::new(0)),
         })
     }
+
+    /// Build a probability matrix from this count matrix using pseudo-counts.
+    ///
+    /// `pseudo` is added to every count before each row is divided by its sum.
+    pub fn to_probability(&self, pseudo: f32) -> ProbabilityMatrix<A, K> {
+        let mut probas = DenseMatrix::new(self.data.rows());
+        for i in 0..self.data.rows() {
+            let row = &mut probas[i];
+            for (p, &count) in row.iter_mut().zip(self.data[i].iter()) {
+                *p = count as f32 + pseudo;
+            }
+            // NOTE: an all-zero row (no counts, `pseudo == 0.0`) divides 0 by 0, giving NaN.
+            let total: f32 = row.iter().sum();
+            row.iter_mut().for_each(|p| *p /= total);
+        }
+        ProbabilityMatrix {
+            alphabet: self.alphabet,
+            data: probas,
+        }
+    }
+}
+
+#[derive(Clone, Debug)]
+pub struct ProbabilityMatrix<A: Alphabet, const K: usize> {
+    pub alphabet: A, // alphabet value copied from the originating count matrix
+    pub data: DenseMatrix<f32, K>, // rows as filled by `CountMatrix::to_probability`
+}
+
+impl<A: Alphabet, const K: usize> ProbabilityMatrix<A, K> {
+    /// Build a weight matrix whose cells are `log2(probability / background frequency)`.
+    pub fn to_weight(&self, background: Background<A, K>) -> WeightMatrix<A, K> {
+        let mut weight = DenseMatrix::new(self.data.rows());
+        for i in 0..self.data.rows() {
+            let odds = self.data[i].iter().zip(&background.frequencies);
+            for (w, (&p, &f)) in weight[i].iter_mut().zip(odds) {
+                *w = (p / f).log2();
+            }
+        }
+        WeightMatrix {
+            background,
+            alphabet: self.alphabet,
+            data: weight,
+        }
+    }
+}
+
+#[derive(Clone, Debug)]
+pub struct Background<A: Alphabet, const K: usize> {
+    pub frequencies: [f32; K], // `K` frequencies used as divisors in `to_weight`
+    _marker: std::marker::PhantomData<A>, // zero-sized tag binding the type to `A`
+}
+
+impl<A: Alphabet, const K: usize> Background<A, K> {
+    /// Create a background in which all `K` frequencies equal `1 / K`.
+    pub fn uniform() -> Self {
+        let frequencies = [(K as f32).recip(); K];
+        let _marker = std::marker::PhantomData;
+        Self { frequencies, _marker }
+    }
+}
+
+#[derive(Clone, Debug)]
+pub struct WeightMatrix<A: Alphabet, const K: usize> {
+    pub alphabet: A, // alphabet value copied from the source probability matrix
+    pub background: Background<A, K>, // frequencies the probabilities were divided by
+    pub data: DenseMatrix<f32, K>, // log2(probability / background) cells, see `to_weight`
 }
-- 
GitLab