From f218b080f610784c3296223cbd72e327cbecdb18 Mon Sep 17 00:00:00 2001
From: Martin Larralde <martin.larralde@embl.de>
Date: Fri, 30 Aug 2024 22:24:30 +0200
Subject: [PATCH] Make `lightmotif.load` support file-like objects and add unit
 tests for various formats

---
 lightmotif-py/lightmotif/__init__.py        |   1 +
 lightmotif-py/lightmotif/io.rs              | 148 +++++++++++++++++---
 lightmotif-py/lightmotif/tests/__init__.py  |   2 +
 lightmotif-py/lightmotif/tests/test_load.py |  83 +++++++++++
 4 files changed, 216 insertions(+), 18 deletions(-)
 create mode 100644 lightmotif-py/lightmotif/tests/test_load.py

diff --git a/lightmotif-py/lightmotif/__init__.py b/lightmotif-py/lightmotif/__init__.py
index 134df07..a4e2f27 100644
--- a/lightmotif-py/lightmotif/__init__.py
+++ b/lightmotif-py/lightmotif/__init__.py
@@ -12,6 +12,7 @@ from .lib import (
     create,
     stripe,
     scan,
+    load
 )
 
 __author__ = lib.__author__
diff --git a/lightmotif-py/lightmotif/io.rs b/lightmotif-py/lightmotif/io.rs
index 99b6c6d..0a5d523 100644
--- a/lightmotif-py/lightmotif/io.rs
+++ b/lightmotif-py/lightmotif/io.rs
@@ -1,20 +1,126 @@
+use std::io::BufRead;
+use std::sync::Arc;
+
 use lightmotif::abc::Alphabet;
 use lightmotif::abc::Dna;
 use lightmotif::abc::Protein;
 use lightmotif_io::error::Error;
 
+use pyo3::exceptions::PyOSError;
 use pyo3::exceptions::PyValueError;
 use pyo3::prelude::*;
+use pyo3::types::PyString;
 
 use super::CountMatrixData;
 use super::Motif;
 use super::ScoringMatrixData;
 use super::WeightMatrixData;
 
+mod pyfile {
+    use std::io::Error as IoError;
+    use std::io::Read;
+    use std::sync::Mutex;
+
+    use pyo3::exceptions::PyOSError;
+    use pyo3::exceptions::PyTypeError;
+    use pyo3::prelude::*;
+    use pyo3::types::PyBytes;
+
+    // ---------------------------------------------------------------------------
+
+    #[macro_export]
+    macro_rules! transmute_file_error {
+        ($self:ident, $e:ident, $msg:expr, $py:expr) => {{ // NOTE(review): `$self` is never used below
+            // Attempt to convert the Python `OSError` into an actual
+            // Rust `std::io::Error` using `from_raw_os_error`.
+            if $e.is_instance_of::<PyOSError>($py) {
+                if let Ok(code) = &$e.value_bound($py).getattr("errno") {
+                    if let Ok(n) = code.extract::<i32>() {
+                        return Err(IoError::from_raw_os_error(n));
+                    }
+                }
+            }
+
+            // If the conversion is not possible for any reason, hand the
+            // Python error back to the interpreter with `restore` and
+            // return a generic Rust error carrying `$msg` instead.
+            $e.restore($py);
+            Err(IoError::new(std::io::ErrorKind::Other, $msg))
+        }};
+    }
+
+    // ---------------------------------------------------------------------------
+
+    /// A wrapper for a Python file that can outlive the GIL.
+    pub struct PyFileRead {
+        file: Mutex<PyObject>, // the wrapped Python object, locked around each `read` call
+    }
+
+    impl PyFileRead {
+        /// Wrap a Python file-like object whose `read` method returns `bytes`.
+        ///
+        /// A zero-length `read` is issued up front so that an object whose
+        /// `read` returns anything else is rejected with a `TypeError`.
+        pub fn from_ref(file: &Bound<PyAny>) -> PyResult<PyFileRead> {
+            let probe = file.call_method1("read", (0,))?;
+            if probe.downcast::<PyBytes>().is_err() {
+                let found = probe.get_type().name()?.to_string();
+                let msg = format!("expected bytes, found {}", found);
+                return Err(PyTypeError::new_err(msg));
+            }
+            let file = Mutex::new(file.to_object(file.py()));
+            Ok(PyFileRead { file })
+        }
+    }
+
+    impl Read for PyFileRead {
+        /// Fill `buf` with the bytes returned by the Python `read(len(buf))` call.
+        fn read(&mut self, buf: &mut [u8]) -> Result<usize, IoError> {
+            Python::with_gil(|py| {
+                let file = self.file.lock().expect("failed to lock file");
+                match file.call_method1(py, pyo3::intern!(py, "read"), (buf.len(),)) {
+                    Ok(obj) => {
+                        // Check `fh.read` returned bytes, else raise a `TypeError`.
+                        let Ok(bytes) = obj.downcast_bound::<PyBytes>(py) else {
+                            let ty = obj.bind(py).get_type().name()?.to_string();
+                            let msg = format!("expected bytes, found {}", ty);
+                            PyTypeError::new_err(msg).restore(py);
+                            let msg = "fh.read did not return bytes";
+                            return Err(IoError::new(std::io::ErrorKind::Other, msg));
+                        };
+                        // Guard the slice: a `read` returning too many bytes must not panic.
+                        let b = bytes.as_bytes();
+                        let Some(dst) = buf.get_mut(..b.len()) else {
+                            let msg = "fh.read returned too many bytes";
+                            return Err(IoError::new(std::io::ErrorKind::InvalidData, msg));
+                        };
+                        dst.copy_from_slice(b);
+                        Ok(b.len())
+                    }
+                    Err(e) => transmute_file_error!(self, e, "read method failed", py),
+                }
+            })
+        }
+    }
+}
+
+fn convert_error(error: Error) -> PyErr { // maps `lightmotif_io` errors to Python exceptions
+    match error {
+        Error::InvalidData => PyValueError::new_err("invalid data"),
+        Error::Io(err) => Arc::try_unwrap(err)
+            .map(PyErr::from) // sole owner: convert the I/O error itself
+            .unwrap_or_else(|e| PyOSError::new_err(e.to_string())), // shared: keep message
+        Error::Nom(err) => PyValueError::new_err(format!("failed to parse input: {}", err)),
+    }
+}
+
 fn convert_jaspar(record: Result<lightmotif_io::jaspar::Record, Error>) -> PyResult<Motif> {
-    let record = record.unwrap();
+    let record = record.map_err(convert_error)?; // surface read/parse failures as Python exceptions
+    let name = record.id().to_string(); // copy the id before `into()` consumes `record`
     let counts = record.into();
-    Python::with_gil(|py| Motif::from_counts(py, counts))
+    let mut motif = Python::with_gil(|py| Motif::from_counts(py, counts))?;
+    motif.name = Some(name);
+    Ok(motif)
 }
 
 fn convert_jaspar16<A>(record: Result<lightmotif_io::jaspar16::Record<A>, Error>) -> PyResult<Motif>
@@ -24,10 +130,9 @@ where
     WeightMatrixData: From<lightmotif::pwm::WeightMatrix<A>>,
     ScoringMatrixData: From<lightmotif::pwm::ScoringMatrix<A>>,
 {
-    let record = record.unwrap();
+    let record = record.map_err(convert_error)?;
     let name = record.id().to_string();
     let counts = record.into_matrix();
-
     let mut motif = Python::with_gil(|py| Motif::from_counts(py, counts))?;
     motif.name = Some(name);
     Ok(motif)
@@ -40,7 +145,7 @@ where
     WeightMatrixData: From<lightmotif::pwm::WeightMatrix<A>>,
     ScoringMatrixData: From<lightmotif::pwm::ScoringMatrix<A>>,
 {
-    let record = record.unwrap();
+    let record = record.map_err(convert_error)?;
     let name = record.accession().or(record.id()).map(String::from);
     let counts = record
         .to_counts()
@@ -56,7 +161,7 @@ where
     WeightMatrixData: From<lightmotif::pwm::WeightMatrix<A>>,
     ScoringMatrixData: From<lightmotif::pwm::ScoringMatrix<A>>,
 {
-    let record = record.unwrap();
+    let record = record.map_err(convert_error)?;
     let name = record.id().to_string();
     let freqs = record.into_matrix();
     let weights = freqs.to_weight(None);
@@ -81,12 +186,19 @@ impl Loader {
     }
 }
 
+/// Load the motifs contained in a file.
 #[pyfunction]
-#[pyo3(signature = (path, format="jaspar", protein=false))]
-pub fn load(path: &str, format: &str, protein: bool) -> PyResult<Loader> {
-    let file = std::fs::File::open(path)
-        .map(std::io::BufReader::new)
-        .unwrap();
+#[pyo3(signature = (file, format="jaspar", *, protein=false))]
+pub fn load(file: Bound<PyAny>, format: &str, protein: bool) -> PyResult<Loader> {
+    let b: Box<dyn BufRead + Send> = if let Ok(s) = file.downcast::<PyString>() {
+        std::fs::File::open(s.to_str()?)
+            .map(std::io::BufReader::new)
+            .map(Box::new)?
+    } else {
+        pyfile::PyFileRead::from_ref(&file)
+            .map(std::io::BufReader::new)
+            .map(Box::new)?
+    };
     let reader: Box<dyn Iterator<Item = PyResult<Motif>> + Send> = match format {
         "jaspar" if protein => {
             return Err(PyValueError::new_err(
@@ -94,23 +206,23 @@ pub fn load(path: &str, format: &str, protein: bool) -> PyResult<Loader> {
             ))
         }
         "jaspar16" if protein => Box::new(
-            lightmotif_io::jaspar16::read::<_, Protein>(file).map(convert_jaspar16::<Protein>),
+            lightmotif_io::jaspar16::read::<_, Protein>(b).map(convert_jaspar16::<Protein>),
         ),
         "transfac" if protein => Box::new(
-            lightmotif_io::transfac::read::<_, Protein>(file).map(convert_transfac::<Protein>),
+            lightmotif_io::transfac::read::<_, Protein>(b).map(convert_transfac::<Protein>),
         ),
         "uniprobe" if protein => Box::new(
-            lightmotif_io::uniprobe::read::<_, Protein>(file).map(convert_uniprobe::<Protein>),
+            lightmotif_io::uniprobe::read::<_, Protein>(b).map(convert_uniprobe::<Protein>),
         ),
-        "jaspar" => Box::new(lightmotif_io::jaspar::read(file).map(convert_jaspar)),
+        "jaspar" => Box::new(lightmotif_io::jaspar::read(b).map(convert_jaspar)),
         "jaspar16" => {
-            Box::new(lightmotif_io::jaspar16::read::<_, Dna>(file).map(convert_jaspar16::<Dna>))
+            Box::new(lightmotif_io::jaspar16::read::<_, Dna>(b).map(convert_jaspar16::<Dna>))
         }
         "transfac" => {
-            Box::new(lightmotif_io::transfac::read::<_, Dna>(file).map(convert_transfac::<Dna>))
+            Box::new(lightmotif_io::transfac::read::<_, Dna>(b).map(convert_transfac::<Dna>))
         }
         "uniprobe" => {
-            Box::new(lightmotif_io::uniprobe::read::<_, Dna>(file).map(convert_uniprobe::<Dna>))
+            Box::new(lightmotif_io::uniprobe::read::<_, Dna>(b).map(convert_uniprobe::<Dna>))
         }
         _ => return Err(PyValueError::new_err(format!("invalid format: {}", format))),
     };
diff --git a/lightmotif-py/lightmotif/tests/__init__.py b/lightmotif-py/lightmotif/tests/__init__.py
index 8980020..53f6b47 100644
--- a/lightmotif-py/lightmotif/tests/__init__.py
+++ b/lightmotif-py/lightmotif/tests/__init__.py
@@ -4,6 +4,7 @@ from . import (
     test_scanner, 
     test_sequence, 
     test_pvalue,
+    test_load
 )
 
 
@@ -13,4 +14,5 @@ def load_tests(loader, suite, pattern):
     suite.addTests(loader.loadTestsFromModule(test_pvalue))
     suite.addTests(loader.loadTestsFromModule(test_scanner))
     suite.addTests(loader.loadTestsFromModule(test_sequence))
+    suite.addTests(loader.loadTestsFromModule(test_load))
     return suite
diff --git a/lightmotif-py/lightmotif/tests/test_load.py b/lightmotif-py/lightmotif/tests/test_load.py
new file mode 100644
index 0000000..27657cb
--- /dev/null
+++ b/lightmotif-py/lightmotif/tests/test_load.py
@@ -0,0 +1,83 @@
+import unittest
+import io
+import textwrap
+import tempfile
+
+import lightmotif
+
+
+class _TestLoad():
+    """Shared `lightmotif.load` checks; subclasses set `format`, `text`, `length` and `first`."""
+    def test_load_file(self):
+        text = textwrap.dedent(self.text).encode()
+        buffer = io.BytesIO(text)  # in-memory binary file-like object
+        motifs = list(lightmotif.load(buffer, self.format))
+        self.assertEqual(len(motifs), self.length)
+        self.assertEqual(motifs[0].name, self.first)
+
+    def test_load_filename(self):
+        text = textwrap.dedent(self.text).encode()
+        with tempfile.NamedTemporaryFile("r+b") as f:  # NOTE(review): reopening by name while open fails on Windows
+            f.write(text)
+            f.flush()  # make the bytes visible to the loader, which reopens the file by name
+            motifs = list(lightmotif.load(f.name, self.format))
+        self.assertEqual(len(motifs), self.length)
+        self.assertEqual(motifs[0].name, self.first)
+
+
+class TestJASPAR(_TestLoad, unittest.TestCase):
+    format = "jaspar"  # format name passed to `lightmotif.load`
+    length = 1  # number of motifs expected from `text`
+    first = "MA0001.3"  # expected `name` of the first motif
+    text = """
+    >MA0001.3	AGL3
+    0      0     82     40     56     35     65     25     64      0
+    92    79      1      4      0      0      1      4      0      0
+    0      0      2      3      1      0      4      3     28     92
+    3     16     10     48     38     60     25     63      3      3
+    """
+
+class TestJASPAR16(_TestLoad, unittest.TestCase):
+    format = "jaspar16"  # format name passed to `lightmotif.load`
+    length = 2  # number of motifs expected from `text`
+    first = "MA0001.3"  # expected `name` of the first motif
+    text = """
+    >MA0001.3	AGL3
+    A  [     0      0     82     40     56     35     65     25     64      0 ]
+    C  [    92     79      1      4      0      0      1      4      0      0 ]
+    G  [     0      0      2      3      1      0      4      3     28     92 ]
+    T  [     3     16     10     48     38     60     25     63      3      3 ]
+    >MA0017.3	NR2F1
+    A  [  7266   6333   8496      0      0      0      0  12059   5116   3229   3276   5681 ]
+    C  [  1692    387      0     30      0      0  12059      0   3055   2966   2470   2912 ]
+    G  [  1153   4869   3791  12059  12059      0      0     91   1618   4395   3886  36863 ]
+    T  [  1948    469      0      0      0  12059      0      0   2270   1469   2427   3466 ]
+    """
+
+class TestTRANSFAC(_TestLoad, unittest.TestCase):
+    format = "transfac"  # format name passed to `lightmotif.load`
+    length = 1  # number of motifs expected from `text`
+    first = "M00005"  # expected `name` of the first motif
+    text = """
+    AC  M00005
+    P0      A      C      G      T
+    01      3      0      0      2      W
+    02      1      1      3      0      G
+    03      3      1      1      0      A
+    04      2      1      2      0      R
+    05      1      2      0      2      Y
+    06      0      5      0      0      C
+    07      5      0      0      0      A
+    08      0      0      5      0      G
+    09      0      5      0      0      C
+    10      0      0      1      4      T
+    11      0      1      4      0      G
+    12      0      2      1      2      Y
+    13      1      0      3      1      G
+    14      0      0      5      0      G
+    15      1      1      1      2      N
+    16      1      4      0      0      C
+    17      2      1      1      1      N
+    18      0      0      3      2      K
+    //
+    """.lstrip("\n")  # NOTE(review): presumably the TRANSFAC reader rejects a leading blank line - confirm
\ No newline at end of file
-- 
GitLab