From 7acbd712842c3c6a2a540a14eb93bbe75c607860 Mon Sep 17 00:00:00 2001
From: Martin Larralde <martin.larralde@embl.de>
Date: Fri, 15 Dec 2023 14:16:01 +0100
Subject: [PATCH] Add `Background` constructor counting symbol occurences in a
 single sequence

---
 lightmotif/src/abc.rs | 57 +++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 57 insertions(+)

diff --git a/lightmotif/src/abc.rs b/lightmotif/src/abc.rs
index 56482a3..ec789b5 100644
--- a/lightmotif/src/abc.rs
+++ b/lightmotif/src/abc.rs
@@ -1,6 +1,7 @@
 //! Digital encoding for biological sequences using an alphabet.
 
 use std::fmt::Debug;
+use std::ops::Index;
 
 use generic_array::ArrayLength;
 use generic_array::GenericArray;
@@ -344,6 +345,54 @@ impl<A: Alphabet> Background<A> {
         })
     }
 
+    /// Create a new background by counting symbol occurences in the given sequence.
+    ///
+    /// Pass `true` as the `unknown` argument to count the unknown symbol when
+    /// computing frequencies, or `false` to only count frequencies of known
+    /// symbols.
+    ///
+    /// # Example
+    /// ```rust
+    /// # use lightmotif::abc::Background;
+    /// # use lightmotif::abc::Dna;
+    /// # use lightmotif::abc::Nucleotide::*;
+    /// let sequence = &[ T, T, A, T, G, T, T, A, C, C ];
+    /// let background = Background::<Dna>::from_sequence(sequence, false).unwrap();
+    /// assert_eq!(background[A], 0.2);
+    /// assert_eq!(background[C], 0.2);
+    /// assert_eq!(background[T], 0.5);
+    /// assert_eq!(background[G], 0.1);
+    /// ```
+    pub fn from_sequence<S>(sequence: S, unknown: bool) -> Result<Self, InvalidData>
+    where
+        S: SymbolCount<A>,
+    {
+        let n = A::default_symbol();
+        let mut total = 0;
+        let mut base_counts = GenericArray::<usize, A::K>::default();
+        for &c in A::symbols() {
+            if unknown || c != n {
+                let count = sequence.count_symbol(c);
+                total += count;
+                base_counts[c.as_index()] += count;
+            }
+        }
+
+        if total == 0 {
+            return Err(InvalidData);
+        }
+
+        let mut frequencies = GenericArray::<f32, A::K>::default();
+        for c in A::symbols() {
+            frequencies[c.as_index()] = base_counts[c.as_index()] as f32 / total as f32;
+        }
+
+        Ok(Self {
+            frequencies,
+            alphabet: std::marker::PhantomData,
+        })
+    }
+
     /// Create a new background by counting symbol occurences in the given sequences.
     ///
     /// Pass `true` as the `unknown` argument to count the unknown symbol when
@@ -437,6 +486,14 @@ impl<A: Alphabet> Default for Background<A> {
     }
 }
 
+impl<A: Alphabet> Index<A::Symbol> for Background<A> {
+    type Output = f32;
+    #[inline]
+    fn index(&self, index: A::Symbol) -> &Self::Output {
+        &self.frequencies()[index.as_index()]
+    }
+}
+
 // --- Pseudocounts ------------------------------------------------------------
 
 /// A structure for storing the pseudocounts over an alphabet.
-- 
GitLab