From 1e43a1715815be1aee86f8c66e923ee5ccd921ac Mon Sep 17 00:00:00 2001
From: Martin Larralde <martin.larralde@embl.de>
Date: Fri, 30 Aug 2024 21:17:39 +0200
Subject: [PATCH] Add `load` function to `lightmotif-py` to support loading
 several `Motif` from a file

---
 lightmotif-io/src/jaspar/mod.rs   |   6 ++
 lightmotif-io/src/jaspar16/mod.rs |   5 ++
 lightmotif-io/src/uniprobe/mod.rs |   5 ++
 lightmotif-py/Cargo.toml          |   3 +
 lightmotif-py/lightmotif/io.rs    | 108 ++++++++++++++++++++++++++++++
 lightmotif-py/lightmotif/lib.rs   |  43 +++++++++++-
 6 files changed, 168 insertions(+), 2 deletions(-)
 create mode 100644 lightmotif-py/lightmotif/io.rs

diff --git a/lightmotif-io/src/jaspar/mod.rs b/lightmotif-io/src/jaspar/mod.rs
index 4c69a67..98168f7 100644
--- a/lightmotif-io/src/jaspar/mod.rs
+++ b/lightmotif-io/src/jaspar/mod.rs
@@ -60,6 +60,12 @@ impl AsRef<CountMatrix<Dna>> for Record {
     }
 }
 
+impl From<Record> for CountMatrix<Dna> {
+    fn from(value: Record) -> Self {
+        value.matrix
+    }
+}
+
 // ---
 
 /// An iterative reader for the JASPAR format.
diff --git a/lightmotif-io/src/jaspar16/mod.rs b/lightmotif-io/src/jaspar16/mod.rs
index e1ec1d5..d302ab7 100644
--- a/lightmotif-io/src/jaspar16/mod.rs
+++ b/lightmotif-io/src/jaspar16/mod.rs
@@ -49,6 +49,11 @@ impl<A: Alphabet> Record<A> {
     pub fn matrix(&self) -> &CountMatrix<A> {
         &self.matrix
     }
+
+    /// Take the count matrix of the record.
+    pub fn into_matrix(self) -> CountMatrix<A> {
+        self.matrix
+    }
 }
 
 impl<A: Alphabet> AsRef<CountMatrix<A>> for Record<A> {
diff --git a/lightmotif-io/src/uniprobe/mod.rs b/lightmotif-io/src/uniprobe/mod.rs
index 8254626..7a0797c 100644
--- a/lightmotif-io/src/uniprobe/mod.rs
+++ b/lightmotif-io/src/uniprobe/mod.rs
@@ -39,6 +39,11 @@ impl<A: Alphabet> Record<A> {
     pub fn matrix(&self) -> &FrequencyMatrix<A> {
         &self.matrix
     }
+
+    /// Take the frequency matrix of the record.
+    pub fn into_matrix(self) -> FrequencyMatrix<A> {
+        self.matrix
+    }
 }
 
 impl<A: Alphabet> AsRef<FrequencyMatrix<A>> for Record<A> {
diff --git a/lightmotif-py/Cargo.toml b/lightmotif-py/Cargo.toml
index c18617b..59d9bf3 100644
--- a/lightmotif-py/Cargo.toml
+++ b/lightmotif-py/Cargo.toml
@@ -19,6 +19,9 @@ doctest = false
 [dependencies.lightmotif]
 path = "../lightmotif"
 version = "0.8.0"
+[dependencies.lightmotif-io]
+path = "../lightmotif-io"
+version = "0.8.0"
 [dependencies.lightmotif-tfmpvalue]
 optional = true
 path = "../lightmotif-tfmpvalue"
diff --git a/lightmotif-py/lightmotif/io.rs b/lightmotif-py/lightmotif/io.rs
new file mode 100644
index 0000000..8976169
--- /dev/null
+++ b/lightmotif-py/lightmotif/io.rs
@@ -0,0 +1,108 @@
+use lightmotif::abc::Alphabet;
+use lightmotif::abc::Dna;
+use lightmotif::abc::Protein;
+use lightmotif_io::error::Error;
+
+use pyo3::exceptions::PyValueError;
+use pyo3::prelude::*;
+
+use super::CountMatrixData;
+use super::Motif;
+use super::ScoringMatrixData;
+use super::WeightMatrixData;
+
+fn convert_jaspar(record: Result<lightmotif_io::jaspar::Record, Error>) -> PyResult<Motif> {
+    let record = record.unwrap();
+    let counts = record.into();
+    Python::with_gil(|py| Motif::from_counts(py, counts))
+}
+
+fn convert_jaspar16<A>(record: Result<lightmotif_io::jaspar16::Record<A>, Error>) -> PyResult<Motif>
+where
+    A: Alphabet,
+    CountMatrixData: From<lightmotif::pwm::CountMatrix<A>>,
+    WeightMatrixData: From<lightmotif::pwm::WeightMatrix<A>>,
+    ScoringMatrixData: From<lightmotif::pwm::ScoringMatrix<A>>,
+{
+    let record = record.unwrap();
+    let counts = record.into_matrix();
+    Python::with_gil(|py| Motif::from_counts(py, counts))
+}
+
+fn convert_transfac<A>(record: Result<lightmotif_io::transfac::Record<A>, Error>) -> PyResult<Motif>
+where
+    A: Alphabet,
+    CountMatrixData: From<lightmotif::pwm::CountMatrix<A>>,
+    WeightMatrixData: From<lightmotif::pwm::WeightMatrix<A>>,
+    ScoringMatrixData: From<lightmotif::pwm::ScoringMatrix<A>>,
+{
+    let record = record.unwrap();
+    let counts = record
+        .to_counts()
+        .ok_or_else(|| PyValueError::new_err("invalid count matrix"))?;
+    Python::with_gil(|py| Motif::from_counts(py, counts))
+}
+
+fn convert_uniprobe<A>(record: Result<lightmotif_io::uniprobe::Record<A>, Error>) -> PyResult<Motif>
+where
+    A: Alphabet,
+    WeightMatrixData: From<lightmotif::pwm::WeightMatrix<A>>,
+    ScoringMatrixData: From<lightmotif::pwm::ScoringMatrix<A>>,
+{
+    let record = record.unwrap();
+    let freqs = record.into_matrix();
+    let weights = freqs.to_weight(None);
+    Python::with_gil(|py| Motif::from_weights(py, weights))
+}
+
+#[pyclass(module = "lightmotif.lib")]
+pub struct Loader {
+    reader: Box<dyn Iterator<Item = PyResult<Motif>> + Send>,
+}
+
+#[pymethods]
+impl Loader {
+    fn __iter__(slf: PyRef<Self>) -> PyRef<Self> {
+        slf
+    }
+
+    fn __next__(mut slf: PyRefMut<Self>) -> Option<PyResult<Motif>> {
+        slf.reader.next()
+    }
+}
+
+#[pyfunction]
+#[pyo3(signature = (path, format="jaspar", protein=false))]
+pub fn load(path: &str, format: &str, protein: bool) -> PyResult<Loader> {
+    let file = std::fs::File::open(path)
+        .map(std::io::BufReader::new)
+        .unwrap();
+    let reader: Box<dyn Iterator<Item = PyResult<Motif>> + Send> = match format {
+        "jaspar" if protein => {
+            return Err(PyValueError::new_err(
+                "cannot read protein motifs from JASPAR format",
+            ))
+        }
+        "jaspar16" if protein => Box::new(
+            lightmotif_io::jaspar16::read::<_, Protein>(file).map(convert_jaspar16::<Protein>),
+        ),
+        "transfac" if protein => Box::new(
+            lightmotif_io::transfac::read::<_, Protein>(file).map(convert_transfac::<Protein>),
+        ),
+        "uniprobe" if protein => Box::new(
+            lightmotif_io::uniprobe::read::<_, Protein>(file).map(convert_uniprobe::<Protein>),
+        ),
+        "jaspar" => Box::new(lightmotif_io::jaspar::read(file).map(convert_jaspar)),
+        "jaspar16" => {
+            Box::new(lightmotif_io::jaspar16::read::<_, Dna>(file).map(convert_jaspar16::<Dna>))
+        }
+        "transfac" => {
+            Box::new(lightmotif_io::transfac::read::<_, Dna>(file).map(convert_transfac::<Dna>))
+        }
+        "uniprobe" => {
+            Box::new(lightmotif_io::uniprobe::read::<_, Dna>(file).map(convert_uniprobe::<Dna>))
+        }
+        _ => return Err(PyValueError::new_err(format!("invalid format: {}", format))),
+    };
+    Ok(Loader { reader })
+}
diff --git a/lightmotif-py/lightmotif/lib.rs b/lightmotif-py/lightmotif/lib.rs
index 2ac43b1..79bc65f 100644
--- a/lightmotif-py/lightmotif/lib.rs
+++ b/lightmotif-py/lightmotif/lib.rs
@@ -34,6 +34,10 @@ use pyo3::types::PyDict;
 use pyo3::types::PyList;
 use pyo3::types::PyString;
 
+mod io;
+
+// --- Macros ------------------------------------------------------------------
+
 macro_rules! impl_matrix_methods {
     ($datatype:ident) => {
         impl $datatype {
@@ -1040,13 +1044,45 @@ impl From<lightmotif::scores::StripedScores<f32, C>> for StripedScores {
 #[derive(Debug)]
 pub struct Motif {
     #[pyo3(get)]
-    counts: Py<CountMatrix>,
+    counts: Option<Py<CountMatrix>>,
     #[pyo3(get)]
     pwm: Py<WeightMatrix>,
     #[pyo3(get)]
     pssm: Py<ScoringMatrix>,
 }
 
+impl Motif {
+    fn from_counts<A>(py: Python, counts: lightmotif::pwm::CountMatrix<A>) -> PyResult<Self>
+    where
+        A: Alphabet,
+        CountMatrixData: From<lightmotif::pwm::CountMatrix<A>>,
+        WeightMatrixData: From<lightmotif::pwm::WeightMatrix<A>>,
+        ScoringMatrixData: From<lightmotif::pwm::ScoringMatrix<A>>,
+    {
+        let weights = counts.to_freq(0.0).to_weight(None);
+        let scoring = weights.to_scoring();
+        Ok(Motif {
+            counts: Some(Py::new(py, CountMatrix::new(counts))?),
+            pwm: Py::new(py, WeightMatrix::new(weights))?,
+            pssm: Py::new(py, ScoringMatrix::new(scoring))?,
+        })
+    }
+
+    fn from_weights<A>(py: Python, weights: lightmotif::pwm::WeightMatrix<A>) -> PyResult<Self>
+    where
+        A: Alphabet,
+        WeightMatrixData: From<lightmotif::pwm::WeightMatrix<A>>,
+        ScoringMatrixData: From<lightmotif::pwm::ScoringMatrix<A>>,
+    {
+        let scoring = weights.to_scoring();
+        Ok(Motif {
+            counts: None,
+            pwm: Py::new(py, WeightMatrix::new(weights))?,
+            pssm: Py::new(py, ScoringMatrix::new(scoring))?,
+        })
+    }
+}
+
 // --- Scanner -----------------------------------------------------------------
 
 #[derive(Debug)]
@@ -1139,7 +1175,7 @@ pub fn create(sequences: Bound<PyAny>, protein: bool) -> PyResult<Motif> {
             let scoring = weights.to_scoring();
 
             Ok(Motif {
-                counts: Py::new(py, CountMatrix::new(data))?,
+                counts: Some(Py::new(py, CountMatrix::new(data))?),
                 pwm: Py::new(py, WeightMatrix::new(weights))?,
                 pssm: Py::new(py, ScoringMatrix::new(scoring))?,
             })
@@ -1251,9 +1287,12 @@ pub fn init<'py>(_py: Python<'py>, m: &Bound<PyModule>) -> PyResult<()> {
     m.add_class::<Scanner>()?;
     m.add_class::<Hit>()?;
 
+    m.add_class::<io::Loader>()?;
+
     m.add_function(wrap_pyfunction!(create, m)?)?;
     m.add_function(wrap_pyfunction!(scan, m)?)?;
     m.add_function(wrap_pyfunction!(stripe, m)?)?;
+    m.add_function(wrap_pyfunction!(io::load, m)?)?;
 
     Ok(())
 }
-- 
GitLab