From f218b080f610784c3296223cbd72e327cbecdb18 Mon Sep 17 00:00:00 2001 From: Martin Larralde <martin.larralde@embl.de> Date: Fri, 30 Aug 2024 22:24:30 +0200 Subject: [PATCH] Make `lightmotif.load` support file-like objects and add unit tests for various formats --- lightmotif-py/lightmotif/__init__.py | 1 + lightmotif-py/lightmotif/io.rs | 148 +++++++++++++++++--- lightmotif-py/lightmotif/tests/__init__.py | 2 + lightmotif-py/lightmotif/tests/test_load.py | 83 +++++++++++ 4 files changed, 216 insertions(+), 18 deletions(-) create mode 100644 lightmotif-py/lightmotif/tests/test_load.py diff --git a/lightmotif-py/lightmotif/__init__.py b/lightmotif-py/lightmotif/__init__.py index 134df07..a4e2f27 100644 --- a/lightmotif-py/lightmotif/__init__.py +++ b/lightmotif-py/lightmotif/__init__.py @@ -12,6 +12,7 @@ from .lib import ( create, stripe, scan, + load ) __author__ = lib.__author__ diff --git a/lightmotif-py/lightmotif/io.rs b/lightmotif-py/lightmotif/io.rs index 99b6c6d..0a5d523 100644 --- a/lightmotif-py/lightmotif/io.rs +++ b/lightmotif-py/lightmotif/io.rs @@ -1,20 +1,126 @@ +use std::io::BufRead; +use std::sync::Arc; + use lightmotif::abc::Alphabet; use lightmotif::abc::Dna; use lightmotif::abc::Protein; use lightmotif_io::error::Error; +use pyo3::exceptions::PyOSError; use pyo3::exceptions::PyValueError; use pyo3::prelude::*; +use pyo3::types::PyString; use super::CountMatrixData; use super::Motif; use super::ScoringMatrixData; use super::WeightMatrixData; +mod pyfile { + use std::io::Error as IoError; + use std::io::Read; + use std::sync::Mutex; + + use pyo3::exceptions::PyOSError; + use pyo3::exceptions::PyTypeError; + use pyo3::prelude::*; + use pyo3::types::PyBytes; + + // --------------------------------------------------------------------------- + + #[macro_export] + macro_rules! transmute_file_error { + ($self:ident, $e:ident, $msg:expr, $py:expr) => {{ + // Attempt to transmute the Python OSError to an actual + // Rust `std::io::Error` using `from_raw_os_error`. + if $e.is_instance_of::<PyOSError>($py) { + if let Ok(code) = &$e.value_bound($py).getattr("errno") { + if let Ok(n) = code.extract::<i32>() { + return Err(IoError::from_raw_os_error(n)); + } + } + } + + // if the conversion is not possible for any reason we fail + // silently, wrapping the Python error, and returning a + // generic Rust error instead. + $e.restore($py); + Err(IoError::new(std::io::ErrorKind::Other, $msg)) + }}; + } + + // --------------------------------------------------------------------------- + + /// A wrapper for a Python file that can outlive the GIL. + pub struct PyFileRead { + file: Mutex<PyObject>, + } + + impl PyFileRead { + pub fn from_ref(file: &Bound<PyAny>) -> PyResult<PyFileRead> { + let res = file.call_method1("read", (0,))?; + if res.downcast::<PyBytes>().is_ok() { + Ok(PyFileRead { + file: Mutex::new(file.to_object(file.py())), + }) + } else { + let ty = res.get_type().name()?.to_string(); + Err(PyTypeError::new_err(format!( + "expected bytes, found {}", + ty + ))) + } + } + } + + impl Read for PyFileRead { + fn read(&mut self, buf: &mut [u8]) -> Result<usize, IoError> { + Python::with_gil(|py| { + let file = self.file.lock().expect("failed to lock file"); + match file + .to_object(py) + .call_method1(py, pyo3::intern!(py, "read"), (buf.len(),)) + { + Ok(obj) => { + // Check `fh.read` returned bytes, else raise a `TypeError`. + if let Ok(bytes) = obj.downcast_bound::<PyBytes>(py) { + let b = bytes.as_bytes(); + (&mut buf[..b.len()]).copy_from_slice(b); + Ok(b.len()) + } else { + let ty = obj.bind(py).get_type().name()?.to_string(); + let msg = format!("expected bytes, found {}", ty); + PyTypeError::new_err(msg).restore(py); + Err(IoError::new( + std::io::ErrorKind::Other, + "fh.read did not return bytes", + )) + } + } + Err(e) => transmute_file_error!(self, e, "read method failed", py), + } + }) + } + } +} + +fn convert_error(error: Error) -> PyErr { + match error { + Error::InvalidData => PyValueError::new_err("invalid data"), + Error::Io(err) => Arc::into_inner(err) + .map(PyErr::from) + .unwrap_or_else(|| PyOSError::new_err("unknown error")), + Error::Nom(err) => PyValueError::new_err(format!("failed to parse input: {}", err)), + } +} + fn convert_jaspar(record: Result<lightmotif_io::jaspar::Record, Error>) -> PyResult<Motif> { - let record = record.unwrap(); + let record = record.map_err(convert_error)?; + let name = record.id().to_string(); let counts = record.into(); - Python::with_gil(|py| Motif::from_counts(py, counts)) + let mut motif = Python::with_gil(|py| Motif::from_counts(py, counts))?; + motif.name = Some(name); + Ok(motif) } fn convert_jaspar16<A>(record: Result<lightmotif_io::jaspar16::Record<A>, Error>) -> PyResult<Motif> @@ -24,10 +130,9 @@ where WeightMatrixData: From<lightmotif::pwm::WeightMatrix<A>>, ScoringMatrixData: From<lightmotif::pwm::ScoringMatrix<A>>, { - let record = record.unwrap(); + let record = record.map_err(convert_error)?; let name = record.id().to_string(); let counts = record.into_matrix(); - let mut motif = Python::with_gil(|py| Motif::from_counts(py, counts))?; motif.name = Some(name); Ok(motif) @@ -40,7 +145,7 @@ where WeightMatrixData: From<lightmotif::pwm::WeightMatrix<A>>, ScoringMatrixData: From<lightmotif::pwm::ScoringMatrix<A>>, { - let record = record.unwrap(); + let record = record.map_err(convert_error)?; let name = record.accession().or(record.id()).map(String::from); let counts = record .to_counts() @@ -56,7 +161,7 @@ where WeightMatrixData: From<lightmotif::pwm::WeightMatrix<A>>, ScoringMatrixData: From<lightmotif::pwm::ScoringMatrix<A>>, { - let record = record.unwrap(); + let record = record.map_err(convert_error)?; let name = record.id().to_string(); let freqs = record.into_matrix(); let weights = freqs.to_weight(None); @@ -81,12 +186,19 @@ impl Loader { } } +/// Load the motifs contained in a file. #[pyfunction] -#[pyo3(signature = (path, format="jaspar", protein=false))] -pub fn load(path: &str, format: &str, protein: bool) -> PyResult<Loader> { - let file = std::fs::File::open(path) - .map(std::io::BufReader::new) - .unwrap(); +#[pyo3(signature = (file, format="jaspar", *, protein=false))] +pub fn load(file: Bound<PyAny>, format: &str, protein: bool) -> PyResult<Loader> { + let b: Box<dyn BufRead + Send> = if let Ok(s) = file.downcast::<PyString>() { + std::fs::File::open(s.to_str()?) + .map(std::io::BufReader::new) + .map(Box::new)? + } else { + pyfile::PyFileRead::from_ref(&file) + .map(std::io::BufReader::new) + .map(Box::new)? + }; let reader: Box<dyn Iterator<Item = PyResult<Motif>> + Send> = match format { "jaspar" if protein => { return Err(PyValueError::new_err( @@ -94,23 +206,23 @@ pub fn load(path: &str, format: &str, protein: bool) -> PyResult<Loader> { )) } "jaspar16" if protein => Box::new( - lightmotif_io::jaspar16::read::<_, Protein>(file).map(convert_jaspar16::<Protein>), + lightmotif_io::jaspar16::read::<_, Protein>(b).map(convert_jaspar16::<Protein>), ), "transfac" if protein => Box::new( - lightmotif_io::transfac::read::<_, Protein>(file).map(convert_transfac::<Protein>), + lightmotif_io::transfac::read::<_, Protein>(b).map(convert_transfac::<Protein>), ), "uniprobe" if protein => Box::new( - lightmotif_io::uniprobe::read::<_, Protein>(file).map(convert_uniprobe::<Protein>), + lightmotif_io::uniprobe::read::<_, Protein>(b).map(convert_uniprobe::<Protein>), ), - "jaspar" => Box::new(lightmotif_io::jaspar::read(file).map(convert_jaspar)), + "jaspar" => Box::new(lightmotif_io::jaspar::read(b).map(convert_jaspar)), "jaspar16" => { - Box::new(lightmotif_io::jaspar16::read::<_, Dna>(file).map(convert_jaspar16::<Dna>)) + Box::new(lightmotif_io::jaspar16::read::<_, Dna>(b).map(convert_jaspar16::<Dna>)) } "transfac" => { - Box::new(lightmotif_io::transfac::read::<_, Dna>(file).map(convert_transfac::<Dna>)) + Box::new(lightmotif_io::transfac::read::<_, Dna>(b).map(convert_transfac::<Dna>)) } "uniprobe" => { - Box::new(lightmotif_io::uniprobe::read::<_, Dna>(file).map(convert_uniprobe::<Dna>)) + Box::new(lightmotif_io::uniprobe::read::<_, Dna>(b).map(convert_uniprobe::<Dna>)) } _ => return Err(PyValueError::new_err(format!("invalid format: {}", format))), }; diff --git a/lightmotif-py/lightmotif/tests/__init__.py b/lightmotif-py/lightmotif/tests/__init__.py index 8980020..53f6b47 100644 --- a/lightmotif-py/lightmotif/tests/__init__.py +++ b/lightmotif-py/lightmotif/tests/__init__.py @@ -4,6 +4,7 @@ from . import ( test_scanner, test_sequence, test_pvalue, + test_load ) @@ -13,4 +14,5 @@ def load_tests(loader, suite, pattern): suite.addTests(loader.loadTestsFromModule(test_pvalue)) suite.addTests(loader.loadTestsFromModule(test_scanner)) suite.addTests(loader.loadTestsFromModule(test_sequence)) + suite.addTests(loader.loadTestsFromModule(test_load)) return suite diff --git a/lightmotif-py/lightmotif/tests/test_load.py b/lightmotif-py/lightmotif/tests/test_load.py new file mode 100644 index 0000000..27657cb --- /dev/null +++ b/lightmotif-py/lightmotif/tests/test_load.py @@ -0,0 +1,83 @@ +import unittest +import io +import textwrap +import tempfile + +import lightmotif + + +class _TestLoad(): + + def test_load_file(self): + text = textwrap.dedent(self.text).encode() + buffer = io.BytesIO(text) + motifs = list(lightmotif.load(buffer, self.format)) + self.assertEqual(len(motifs), self.length) + self.assertEqual(motifs[0].name, self.first) + + def test_load_filename(self): + text = textwrap.dedent(self.text).encode() + with tempfile.NamedTemporaryFile("r+b") as f: + f.write(text) + f.flush() + motifs = list(lightmotif.load(f.name, self.format)) + self.assertEqual(len(motifs), self.length) + self.assertEqual(motifs[0].name, self.first) + + +class TestJASPAR(_TestLoad, unittest.TestCase): + format = "jaspar" + length = 1 + first = "MA0001.3" + text = """ + >MA0001.3 AGL3 + 0 0 82 40 56 35 65 25 64 0 + 92 79 1 4 0 0 1 4 0 0 + 0 0 2 3 1 0 4 3 28 92 + 3 16 10 48 38 60 25 63 3 3 + """ + +class TestJASPAR16(_TestLoad, unittest.TestCase): + format = "jaspar16" + length = 2 + first = "MA0001.3" + text = """ + >MA0001.3 AGL3 + A [ 0 0 82 40 56 35 65 25 64 0 ] + C [ 92 79 1 4 0 0 1 4 0 0 ] + G [ 0 0 2 3 1 0 4 3 28 92 ] + T [ 3 16 10 48 38 60 25 63 3 3 ] + >MA0017.3 NR2F1 + A [ 7266 6333 8496 0 0 0 0 12059 5116 3229 3276 5681 ] + C [ 1692 387 0 30 0 0 12059 0 3055 2966 2470 2912 ] + G [ 1153 4869 3791 12059 12059 0 0 91 1618 4395 3886 36863 ] + T [ 1948 469 0 0 0 12059 0 0 2270 1469 2427 3466 ] + """ + +class TestTRANSFAC(_TestLoad, unittest.TestCase): + format = "transfac" + length = 1 + first = "M00005" + text = """ + AC M00005 + P0 A C G T + 01 3 0 0 2 W + 02 1 1 3 0 G + 03 3 1 1 0 A + 04 2 1 2 0 R + 05 1 2 0 2 Y + 06 0 5 0 0 C + 07 5 0 0 0 A + 08 0 0 5 0 G + 09 0 5 0 0 C + 10 0 0 1 4 T + 11 0 1 4 0 G + 12 0 2 1 2 Y + 13 1 0 3 1 G + 14 0 0 5 0 G + 15 1 1 1 2 N + 16 1 4 0 0 C + 17 2 1 1 1 N + 18 0 0 3 2 K + // + """.lstrip("\n") \ No newline at end of file -- GitLab