From ffe6ac099cad9186e29be9c8e4b78b57beda0975 Mon Sep 17 00:00:00 2001
From: Martin Larralde <martin.larralde@embl.de>
Date: Sun, 1 Sep 2024 21:07:37 +0200
Subject: [PATCH] Add `Motif` subclasss to store metadata from JASPAR or
 TRANSFAC records

---
 lightmotif-py/lightmotif/io.rs              | 271 +++++++++-----------
 lightmotif-py/lightmotif/lib.rs             |   6 +-
 lightmotif-py/lightmotif/pyfile.rs          |  85 ++++++
 lightmotif-py/lightmotif/tests/test_load.py |   2 +-
 4 files changed, 210 insertions(+), 154 deletions(-)
 create mode 100644 lightmotif-py/lightmotif/pyfile.rs

diff --git a/lightmotif-py/lightmotif/io.rs b/lightmotif-py/lightmotif/io.rs
index 0a5d523..ea1432c 100644
--- a/lightmotif-py/lightmotif/io.rs
+++ b/lightmotif-py/lightmotif/io.rs
@@ -11,98 +11,13 @@ use pyo3::exceptions::PyValueError;
 use pyo3::prelude::*;
 use pyo3::types::PyString;
 
+use super::pyfile::PyFileRead;
 use super::CountMatrixData;
 use super::Motif;
 use super::ScoringMatrixData;
 use super::WeightMatrixData;
 
-mod pyfile {
-    use std::io::Error as IoError;
-    use std::io::Read;
-    use std::sync::Mutex;
-
-    use pyo3::exceptions::PyOSError;
-    use pyo3::exceptions::PyTypeError;
-    use pyo3::prelude::*;
-    use pyo3::types::PyBytes;
-
-    // ---------------------------------------------------------------------------
-
-    #[macro_export]
-    macro_rules! transmute_file_error {
-        ($self:ident, $e:ident, $msg:expr, $py:expr) => {{
-            // Attempt to transmute the Python OSError to an actual
-            // Rust `std::io::Error` using `from_raw_os_error`.
-            if $e.is_instance_of::<PyOSError>($py) {
-                if let Ok(code) = &$e.value_bound($py).getattr("errno") {
-                    if let Ok(n) = code.extract::<i32>() {
-                        return Err(IoError::from_raw_os_error(n));
-                    }
-                }
-            }
-
-            // if the conversion is not possible for any reason we fail
-            // silently, wrapping the Python error, and returning a
-            // generic Rust error instead.
-            $e.restore($py);
-            Err(IoError::new(std::io::ErrorKind::Other, $msg))
-        }};
-    }
-
-    // ---------------------------------------------------------------------------
-
-    /// A wrapper for a Python file that can outlive the GIL.
-    pub struct PyFileRead {
-        file: Mutex<PyObject>,
-    }
-
-    impl PyFileRead {
-        pub fn from_ref(file: &Bound<PyAny>) -> PyResult<PyFileRead> {
-            let res = file.call_method1("read", (0,))?;
-            if res.downcast::<PyBytes>().is_ok() {
-                Ok(PyFileRead {
-                    file: Mutex::new(file.to_object(file.py())),
-                })
-            } else {
-                let ty = res.get_type().name()?.to_string();
-                Err(PyTypeError::new_err(format!(
-                    "expected bytes, found {}",
-                    ty
-                )))
-            }
-        }
-    }
-
-    impl Read for PyFileRead {
-        fn read(&mut self, buf: &mut [u8]) -> Result<usize, IoError> {
-            Python::with_gil(|py| {
-                let file = self.file.lock().expect("failed to lock file");
-                match file
-                    .to_object(py)
-                    .call_method1(py, pyo3::intern!(py, "read"), (buf.len(),))
-                {
-                    Ok(obj) => {
-                        // Check `fh.read` returned bytes, else raise a `TypeError`.
-                        if let Ok(bytes) = obj.downcast_bound::<PyBytes>(py) {
-                            let b = bytes.as_bytes();
-                            (&mut buf[..b.len()]).copy_from_slice(b);
-                            Ok(b.len())
-                        } else {
-                            let ty = obj.bind(py).get_type().name()?.to_string();
-                            let msg = format!("expected bytes, found {}", ty);
-                            PyTypeError::new_err(msg).restore(py);
-                            Err(IoError::new(
-                                std::io::ErrorKind::Other,
-                                "fh.read did not return bytes",
-                            ))
-                        }
-                    }
-                    Err(e) => transmute_file_error!(self, e, "read method failed", py),
-                }
-            })
-        }
-    }
-}
+// --- Error handling ----------------------------------------------------------
 
 fn convert_error(error: Error) -> PyErr {
     match error {
@@ -114,65 +29,117 @@ fn convert_error(error: Error) -> PyErr {
     }
 }
 
-fn convert_jaspar(record: Result<lightmotif_io::jaspar::Record, Error>) -> PyResult<Motif> {
-    let record = record.map_err(convert_error)?;
-    let name = record.id().to_string();
-    let counts = record.into();
-    let mut motif = Python::with_gil(|py| Motif::from_counts(py, counts))?;
-    motif.name = Some(name);
-    Ok(motif)
+// --- JASPAR motif ------------------------------------------------------------
+
+#[pyclass(module = "lightmotif.lib", extends = Motif)]
+pub struct JasparMotif {
+    #[pyo3(get)]
+    description: Option<String>,
+}
+
+impl JasparMotif {
+    fn convert(record: Result<lightmotif_io::jaspar::Record, Error>) -> PyResult<PyObject> {
+        let record = record.map_err(convert_error)?;
+        let name = record.id().to_string();
+        let description = record.description().map(String::from);
+        let counts = record.into();
+        Python::with_gil(|py| {
+            let mut motif = Motif::from_counts(py, counts)?;
+            motif.name = Some(name);
+            let init = PyClassInitializer::from(motif).add_subclass(JasparMotif { description });
+            Ok(Py::new(py, init)?.to_object(py))
+        })
+    }
+
+    fn convert16<A>(record: Result<lightmotif_io::jaspar16::Record<A>, Error>) -> PyResult<PyObject>
+    where
+        A: Alphabet,
+        CountMatrixData: From<lightmotif::pwm::CountMatrix<A>>,
+        WeightMatrixData: From<lightmotif::pwm::WeightMatrix<A>>,
+        ScoringMatrixData: From<lightmotif::pwm::ScoringMatrix<A>>,
+    {
+        let record = record.map_err(convert_error)?;
+        let name = record.id().to_string();
+        let description = record.description().map(String::from);
+        let counts = record.into_matrix();
+        Python::with_gil(|py| {
+            let mut motif = Motif::from_counts(py, counts)?;
+            motif.name = Some(name);
+            let init = PyClassInitializer::from(motif).add_subclass(JasparMotif { description });
+            Ok(Py::new(py, init)?.to_object(py))
+        })
+    }
 }
 
-fn convert_jaspar16<A>(record: Result<lightmotif_io::jaspar16::Record<A>, Error>) -> PyResult<Motif>
-where
-    A: Alphabet,
-    CountMatrixData: From<lightmotif::pwm::CountMatrix<A>>,
-    WeightMatrixData: From<lightmotif::pwm::WeightMatrix<A>>,
-    ScoringMatrixData: From<lightmotif::pwm::ScoringMatrix<A>>,
-{
-    let record = record.map_err(convert_error)?;
-    let name = record.id().to_string();
-    let counts = record.into_matrix();
-    let mut motif = Python::with_gil(|py| Motif::from_counts(py, counts))?;
-    motif.name = Some(name);
-    Ok(motif)
+// --- UniPROBE motif ----------------------------------------------------------
+
+#[pyclass(module = "lightmotif.lib", extends = Motif)]
+pub struct UniprobeMotif {}
+
+impl UniprobeMotif {
+    fn convert<A>(record: Result<lightmotif_io::uniprobe::Record<A>, Error>) -> PyResult<PyObject>
+    where
+        A: Alphabet,
+        WeightMatrixData: From<lightmotif::pwm::WeightMatrix<A>>,
+        ScoringMatrixData: From<lightmotif::pwm::ScoringMatrix<A>>,
+    {
+        let record = record.map_err(convert_error)?;
+        let name = record.id().to_string();
+        let freqs = record.into_matrix();
+        let weights = freqs.to_weight(None);
+        Python::with_gil(|py| {
+            let mut motif = Motif::from_weights(py, weights)?;
+            motif.name = Some(name);
+            let init = PyClassInitializer::from(motif).add_subclass(UniprobeMotif {});
+            Ok(Py::new(py, init)?.to_object(py))
+        })
+    }
 }
 
-fn convert_transfac<A>(record: Result<lightmotif_io::transfac::Record<A>, Error>) -> PyResult<Motif>
-where
-    A: Alphabet,
-    CountMatrixData: From<lightmotif::pwm::CountMatrix<A>>,
-    WeightMatrixData: From<lightmotif::pwm::WeightMatrix<A>>,
-    ScoringMatrixData: From<lightmotif::pwm::ScoringMatrix<A>>,
-{
-    let record = record.map_err(convert_error)?;
-    let name = record.accession().or(record.id()).map(String::from);
-    let counts = record
-        .to_counts()
-        .ok_or_else(|| PyValueError::new_err("invalid count matrix"))?;
-    let mut motif = Python::with_gil(|py| Motif::from_counts(py, counts))?;
-    motif.name = name;
-    Ok(motif)
+// --- TRANSFAC records --------------------------------------------------------
+
+#[pyclass(module = "lightmotif.lib", extends = Motif)]
+pub struct TransfacMotif {
+    id: Option<String>,
+    accession: Option<String>,
+    description: Option<String>,
 }
 
-fn convert_uniprobe<A>(record: Result<lightmotif_io::uniprobe::Record<A>, Error>) -> PyResult<Motif>
-where
-    A: Alphabet,
-    WeightMatrixData: From<lightmotif::pwm::WeightMatrix<A>>,
-    ScoringMatrixData: From<lightmotif::pwm::ScoringMatrix<A>>,
-{
-    let record = record.map_err(convert_error)?;
-    let name = record.id().to_string();
-    let freqs = record.into_matrix();
-    let weights = freqs.to_weight(None);
-    let mut motif = Python::with_gil(|py| Motif::from_weights(py, weights))?;
-    motif.name = Some(name);
-    Ok(motif)
+impl TransfacMotif {
+    fn convert<A>(record: Result<lightmotif_io::transfac::Record<A>, Error>) -> PyResult<PyObject>
+    where
+        A: Alphabet,
+        CountMatrixData: From<lightmotif::pwm::CountMatrix<A>>,
+        WeightMatrixData: From<lightmotif::pwm::WeightMatrix<A>>,
+        ScoringMatrixData: From<lightmotif::pwm::ScoringMatrix<A>>,
+    {
+        let record = record.map_err(convert_error)?;
+
+        let name = record.name().map(String::from);
+        let description = record.description().map(String::from);
+        let id = record.id().map(String::from);
+        let accession = record.accession().map(String::from);
+        let counts = record
+            .to_counts()
+            .ok_or_else(|| PyValueError::new_err("invalid count matrix"))?;
+        Python::with_gil(|py| {
+            let mut motif = Motif::from_counts(py, counts)?;
+            motif.name = name;
+            let init = PyClassInitializer::from(motif).add_subclass(TransfacMotif {
+                description,
+                accession,
+                id,
+            });
+            Ok(Py::new(py, init)?.to_object(py))
+        })
+    }
 }
 
+// --- Loader ------------------------------------------------------------------
+
 #[pyclass(module = "lightmotif.lib")]
 pub struct Loader {
-    reader: Box<dyn Iterator<Item = PyResult<Motif>> + Send>,
+    reader: Box<dyn Iterator<Item = PyResult<PyObject>> + Send>,
 }
 
 #[pymethods]
@@ -181,7 +148,7 @@ impl Loader {
         slf
     }
 
-    fn __next__(mut slf: PyRefMut<Self>) -> Option<PyResult<Motif>> {
+    fn __next__(mut slf: PyRefMut<Self>) -> Option<PyResult<PyObject>> {
         slf.reader.next()
     }
 }
@@ -195,34 +162,34 @@ pub fn load(file: Bound<PyAny>, format: &str, protein: bool) -> PyResult<Loader>
             .map(std::io::BufReader::new)
             .map(Box::new)?
     } else {
-        pyfile::PyFileRead::from_ref(&file)
+        PyFileRead::from_ref(&file)
             .map(std::io::BufReader::new)
             .map(Box::new)?
     };
-    let reader: Box<dyn Iterator<Item = PyResult<Motif>> + Send> = match format {
+    let reader: Box<dyn Iterator<Item = PyResult<PyObject>> + Send> = match format {
         "jaspar" if protein => {
             return Err(PyValueError::new_err(
                 "cannot read protein motifs from JASPAR format",
             ))
         }
-        "jaspar16" if protein => Box::new(
-            lightmotif_io::jaspar16::read::<_, Protein>(b).map(convert_jaspar16::<Protein>),
-        ),
-        "transfac" if protein => Box::new(
-            lightmotif_io::transfac::read::<_, Protein>(b).map(convert_transfac::<Protein>),
-        ),
-        "uniprobe" if protein => Box::new(
-            lightmotif_io::uniprobe::read::<_, Protein>(b).map(convert_uniprobe::<Protein>),
-        ),
-        "jaspar" => Box::new(lightmotif_io::jaspar::read(b).map(convert_jaspar)),
+        "jaspar16" if protein => {
+            Box::new(lightmotif_io::jaspar16::read::<_, Protein>(b).map(JasparMotif::convert16))
+        }
+        "transfac" if protein => {
+            Box::new(lightmotif_io::transfac::read::<_, Protein>(b).map(TransfacMotif::convert))
+        }
+        "uniprobe" if protein => {
+            Box::new(lightmotif_io::uniprobe::read::<_, Protein>(b).map(UniprobeMotif::convert))
+        }
+        "jaspar" => Box::new(lightmotif_io::jaspar::read(b).map(JasparMotif::convert)),
         "jaspar16" => {
-            Box::new(lightmotif_io::jaspar16::read::<_, Dna>(b).map(convert_jaspar16::<Dna>))
+            Box::new(lightmotif_io::jaspar16::read::<_, Dna>(b).map(JasparMotif::convert16))
         }
         "transfac" => {
-            Box::new(lightmotif_io::transfac::read::<_, Dna>(b).map(convert_transfac::<Dna>))
+            Box::new(lightmotif_io::transfac::read::<_, Dna>(b).map(TransfacMotif::convert))
         }
         "uniprobe" => {
-            Box::new(lightmotif_io::uniprobe::read::<_, Dna>(b).map(convert_uniprobe::<Dna>))
+            Box::new(lightmotif_io::uniprobe::read::<_, Dna>(b).map(UniprobeMotif::convert))
         }
         _ => return Err(PyValueError::new_err(format!("invalid format: {}", format))),
     };
diff --git a/lightmotif-py/lightmotif/lib.rs b/lightmotif-py/lightmotif/lib.rs
index 959abda..310aae7 100644
--- a/lightmotif-py/lightmotif/lib.rs
+++ b/lightmotif-py/lightmotif/lib.rs
@@ -37,6 +37,7 @@ use pyo3::types::PyList;
 use pyo3::types::PyString;
 
 mod io;
+mod pyfile;
 
 // --- Macros ------------------------------------------------------------------
 
@@ -1042,7 +1043,7 @@ impl From<lightmotif::scores::StripedScores<f32, C>> for StripedScores {
 
 // --- Motif -------------------------------------------------------------------
 
-#[pyclass(module = "lightmotif.lib")]
+#[pyclass(module = "lightmotif.lib", subclass)]
 #[derive(Debug)]
 pub struct Motif {
     #[pyo3(get)]
@@ -1294,6 +1295,9 @@ pub fn init<'py>(_py: Python<'py>, m: &Bound<PyModule>) -> PyResult<()> {
     m.add_class::<StripedScores>()?;
 
     m.add_class::<Motif>()?;
+    m.add_class::<io::TransfacMotif>()?;
+    m.add_class::<io::JasparMotif>()?;
+    m.add_class::<io::UniprobeMotif>()?;
 
     m.add_class::<Scanner>()?;
     m.add_class::<Hit>()?;
diff --git a/lightmotif-py/lightmotif/pyfile.rs b/lightmotif-py/lightmotif/pyfile.rs
new file mode 100644
index 0000000..4df8849
--- /dev/null
+++ b/lightmotif-py/lightmotif/pyfile.rs
@@ -0,0 +1,85 @@
+use std::io::Error as IoError;
+use std::io::Read;
+use std::sync::Mutex;
+
+use pyo3::exceptions::PyOSError;
+use pyo3::exceptions::PyTypeError;
+use pyo3::prelude::*;
+use pyo3::types::PyBytes;
+
+// ---------------------------------------------------------------------------
+
+#[macro_export]
+macro_rules! transmute_file_error {
+    ($self:ident, $e:ident, $msg:expr, $py:expr) => {{
+        // Attempt to transmute the Python OSError to an actual
+        // Rust `std::io::Error` using `from_raw_os_error`.
+        if $e.is_instance_of::<PyOSError>($py) {
+            if let Ok(code) = &$e.value_bound($py).getattr("errno") {
+                if let Ok(n) = code.extract::<i32>() {
+                    return Err(IoError::from_raw_os_error(n));
+                }
+            }
+        }
+
+        // if the conversion is not possible for any reason we fail
+        // silently, wrapping the Python error, and returning a
+        // generic Rust error instead.
+        $e.restore($py);
+        Err(IoError::new(std::io::ErrorKind::Other, $msg))
+    }};
+}
+
+// ---------------------------------------------------------------------------
+
+/// A wrapper for a Python file that can outlive the GIL.
+pub struct PyFileRead {
+    file: Mutex<PyObject>,
+}
+
+impl PyFileRead {
+    pub fn from_ref(file: &Bound<PyAny>) -> PyResult<PyFileRead> {
+        let res = file.call_method1("read", (0,))?;
+        if res.downcast::<PyBytes>().is_ok() {
+            Ok(PyFileRead {
+                file: Mutex::new(file.to_object(file.py())),
+            })
+        } else {
+            let ty = res.get_type().name()?.to_string();
+            Err(PyTypeError::new_err(format!(
+                "expected bytes, found {}",
+                ty
+            )))
+        }
+    }
+}
+
+impl Read for PyFileRead {
+    fn read(&mut self, buf: &mut [u8]) -> Result<usize, IoError> {
+        Python::with_gil(|py| {
+            let file = self.file.lock().expect("failed to lock file");
+            match file
+                .to_object(py)
+                .call_method1(py, pyo3::intern!(py, "read"), (buf.len(),))
+            {
+                Ok(obj) => {
+                    // Check `fh.read` returned bytes, else raise a `TypeError`.
+                    if let Ok(bytes) = obj.downcast_bound::<PyBytes>(py) {
+                        let b = bytes.as_bytes();
+                        (&mut buf[..b.len()]).copy_from_slice(b);
+                        Ok(b.len())
+                    } else {
+                        let ty = obj.bind(py).get_type().name()?.to_string();
+                        let msg = format!("expected bytes, found {}", ty);
+                        PyTypeError::new_err(msg).restore(py);
+                        Err(IoError::new(
+                            std::io::ErrorKind::Other,
+                            "fh.read did not return bytes",
+                        ))
+                    }
+                }
+                Err(e) => transmute_file_error!(self, e, "read method failed", py),
+            }
+        })
+    }
+}
diff --git a/lightmotif-py/lightmotif/tests/test_load.py b/lightmotif-py/lightmotif/tests/test_load.py
index 27657cb..cb6f4b5 100644
--- a/lightmotif-py/lightmotif/tests/test_load.py
+++ b/lightmotif-py/lightmotif/tests/test_load.py
@@ -59,7 +59,7 @@ class TestTRANSFAC(_TestLoad, unittest.TestCase):
     length = 1
     first = "M00005"
     text = """
-    AC  M00005
+    NA  M00005
     P0      A      C      G      T
     01      3      0      0      2      W
     02      1      1      3      0      G
-- 
GitLab