Skip to content
Snippets Groups Projects
Commit f218b080 authored by Martin Larralde's avatar Martin Larralde
Browse files

Make `lightmotif.load` support file-like objects and add unit tests for various formats

parent a3cfd3f4
No related branches found
No related tags found
No related merge requests found
......@@ -12,6 +12,7 @@ from .lib import (
create,
stripe,
scan,
load
)
__author__ = lib.__author__
......
use std::io::BufRead;
use std::sync::Arc;
use lightmotif::abc::Alphabet;
use lightmotif::abc::Dna;
use lightmotif::abc::Protein;
use lightmotif_io::error::Error;
use pyo3::exceptions::PyOSError;
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::types::PyString;
use super::CountMatrixData;
use super::Motif;
use super::ScoringMatrixData;
use super::WeightMatrixData;
mod pyfile {
use std::io::Error as IoError;
use std::io::Read;
use std::sync::Mutex;
use pyo3::exceptions::PyOSError;
use pyo3::exceptions::PyTypeError;
use pyo3::prelude::*;
use pyo3::types::PyBytes;
// ---------------------------------------------------------------------------
#[macro_export]
macro_rules! transmute_file_error {
($self:ident, $e:ident, $msg:expr, $py:expr) => {{
// Attempt to transmute the Python OSError to an actual
// Rust `std::io::Error` using `from_raw_os_error`.
if $e.is_instance_of::<PyOSError>($py) {
if let Ok(code) = &$e.value_bound($py).getattr("errno") {
if let Ok(n) = code.extract::<i32>() {
return Err(IoError::from_raw_os_error(n));
}
}
}
// if the conversion is not possible for any reason we fail
// silently, wrapping the Python error, and returning a
// generic Rust error instead.
$e.restore($py);
Err(IoError::new(std::io::ErrorKind::Other, $msg))
}};
}
// ---------------------------------------------------------------------------
/// A wrapper for a Python file that can outlive the GIL.
pub struct PyFileRead {
file: Mutex<PyObject>,
}
impl PyFileRead {
pub fn from_ref(file: &Bound<PyAny>) -> PyResult<PyFileRead> {
let res = file.call_method1("read", (0,))?;
if res.downcast::<PyBytes>().is_ok() {
Ok(PyFileRead {
file: Mutex::new(file.to_object(file.py())),
})
} else {
let ty = res.get_type().name()?.to_string();
Err(PyTypeError::new_err(format!(
"expected bytes, found {}",
ty
)))
}
}
}
impl Read for PyFileRead {
fn read(&mut self, buf: &mut [u8]) -> Result<usize, IoError> {
Python::with_gil(|py| {
let file = self.file.lock().expect("failed to lock file");
match file
.to_object(py)
.call_method1(py, pyo3::intern!(py, "read"), (buf.len(),))
{
Ok(obj) => {
// Check `fh.read` returned bytes, else raise a `TypeError`.
if let Ok(bytes) = obj.downcast_bound::<PyBytes>(py) {
let b = bytes.as_bytes();
(&mut buf[..b.len()]).copy_from_slice(b);
Ok(b.len())
} else {
let ty = obj.bind(py).get_type().name()?.to_string();
let msg = format!("expected bytes, found {}", ty);
PyTypeError::new_err(msg).restore(py);
Err(IoError::new(
std::io::ErrorKind::Other,
"fh.read did not return bytes",
))
}
}
Err(e) => transmute_file_error!(self, e, "read method failed", py),
}
})
}
}
}
fn convert_error(error: Error) -> PyErr {
match error {
Error::InvalidData => PyValueError::new_err("invalid data"),
Error::Io(err) => Arc::into_inner(err)
.map(PyErr::from)
.unwrap_or_else(|| PyOSError::new_err("unknown error")),
Error::Nom(err) => PyValueError::new_err(format!("failed to parse input: {}", err)),
}
}
fn convert_jaspar(record: Result<lightmotif_io::jaspar::Record, Error>) -> PyResult<Motif> {
let record = record.unwrap();
let record = record.map_err(convert_error)?;
let name = record.id().to_string();
let counts = record.into();
Python::with_gil(|py| Motif::from_counts(py, counts))
let mut motif = Python::with_gil(|py| Motif::from_counts(py, counts))?;
motif.name = Some(name);
Ok(motif)
}
fn convert_jaspar16<A>(record: Result<lightmotif_io::jaspar16::Record<A>, Error>) -> PyResult<Motif>
......@@ -24,10 +130,9 @@ where
WeightMatrixData: From<lightmotif::pwm::WeightMatrix<A>>,
ScoringMatrixData: From<lightmotif::pwm::ScoringMatrix<A>>,
{
let record = record.unwrap();
let record = record.map_err(convert_error)?;
let name = record.id().to_string();
let counts = record.into_matrix();
let mut motif = Python::with_gil(|py| Motif::from_counts(py, counts))?;
motif.name = Some(name);
Ok(motif)
......@@ -40,7 +145,7 @@ where
WeightMatrixData: From<lightmotif::pwm::WeightMatrix<A>>,
ScoringMatrixData: From<lightmotif::pwm::ScoringMatrix<A>>,
{
let record = record.unwrap();
let record = record.map_err(convert_error)?;
let name = record.accession().or(record.id()).map(String::from);
let counts = record
.to_counts()
......@@ -56,7 +161,7 @@ where
WeightMatrixData: From<lightmotif::pwm::WeightMatrix<A>>,
ScoringMatrixData: From<lightmotif::pwm::ScoringMatrix<A>>,
{
let record = record.unwrap();
let record = record.map_err(convert_error)?;
let name = record.id().to_string();
let freqs = record.into_matrix();
let weights = freqs.to_weight(None);
......@@ -81,12 +186,19 @@ impl Loader {
}
}
/// Load the motifs contained in a file.
#[pyfunction]
#[pyo3(signature = (path, format="jaspar", protein=false))]
pub fn load(path: &str, format: &str, protein: bool) -> PyResult<Loader> {
let file = std::fs::File::open(path)
.map(std::io::BufReader::new)
.unwrap();
#[pyo3(signature = (file, format="jaspar", *, protein=false))]
pub fn load(file: Bound<PyAny>, format: &str, protein: bool) -> PyResult<Loader> {
let b: Box<dyn BufRead + Send> = if let Ok(s) = file.downcast::<PyString>() {
std::fs::File::open(s.to_str()?)
.map(std::io::BufReader::new)
.map(Box::new)?
} else {
pyfile::PyFileRead::from_ref(&file)
.map(std::io::BufReader::new)
.map(Box::new)?
};
let reader: Box<dyn Iterator<Item = PyResult<Motif>> + Send> = match format {
"jaspar" if protein => {
return Err(PyValueError::new_err(
......@@ -94,23 +206,23 @@ pub fn load(path: &str, format: &str, protein: bool) -> PyResult<Loader> {
))
}
"jaspar16" if protein => Box::new(
lightmotif_io::jaspar16::read::<_, Protein>(file).map(convert_jaspar16::<Protein>),
lightmotif_io::jaspar16::read::<_, Protein>(b).map(convert_jaspar16::<Protein>),
),
"transfac" if protein => Box::new(
lightmotif_io::transfac::read::<_, Protein>(file).map(convert_transfac::<Protein>),
lightmotif_io::transfac::read::<_, Protein>(b).map(convert_transfac::<Protein>),
),
"uniprobe" if protein => Box::new(
lightmotif_io::uniprobe::read::<_, Protein>(file).map(convert_uniprobe::<Protein>),
lightmotif_io::uniprobe::read::<_, Protein>(b).map(convert_uniprobe::<Protein>),
),
"jaspar" => Box::new(lightmotif_io::jaspar::read(file).map(convert_jaspar)),
"jaspar" => Box::new(lightmotif_io::jaspar::read(b).map(convert_jaspar)),
"jaspar16" => {
Box::new(lightmotif_io::jaspar16::read::<_, Dna>(file).map(convert_jaspar16::<Dna>))
Box::new(lightmotif_io::jaspar16::read::<_, Dna>(b).map(convert_jaspar16::<Dna>))
}
"transfac" => {
Box::new(lightmotif_io::transfac::read::<_, Dna>(file).map(convert_transfac::<Dna>))
Box::new(lightmotif_io::transfac::read::<_, Dna>(b).map(convert_transfac::<Dna>))
}
"uniprobe" => {
Box::new(lightmotif_io::uniprobe::read::<_, Dna>(file).map(convert_uniprobe::<Dna>))
Box::new(lightmotif_io::uniprobe::read::<_, Dna>(b).map(convert_uniprobe::<Dna>))
}
_ => return Err(PyValueError::new_err(format!("invalid format: {}", format))),
};
......
......@@ -4,6 +4,7 @@ from . import (
test_scanner,
test_sequence,
test_pvalue,
test_load
)
......@@ -13,4 +14,5 @@ def load_tests(loader, suite, pattern):
suite.addTests(loader.loadTestsFromModule(test_pvalue))
suite.addTests(loader.loadTestsFromModule(test_scanner))
suite.addTests(loader.loadTestsFromModule(test_sequence))
suite.addTests(loader.loadTestsFromModule(test_load))
return suite
import unittest
import io
import textwrap
import tempfile
import lightmotif
class _TestLoad():
def test_load_file(self):
text = textwrap.dedent(self.text).encode()
buffer = io.BytesIO(text)
motifs = list(lightmotif.load(buffer, self.format))
self.assertEqual(len(motifs), self.length)
self.assertEqual(motifs[0].name, self.first)
def test_load_filename(self):
text = textwrap.dedent(self.text).encode()
with tempfile.NamedTemporaryFile("r+b") as f:
f.write(text)
f.flush()
motifs = list(lightmotif.load(f.name, self.format))
self.assertEqual(len(motifs), self.length)
self.assertEqual(motifs[0].name, self.first)
class TestJASPAR(_TestLoad, unittest.TestCase):
format = "jaspar"
length = 1
first = "MA0001.3"
text = """
>MA0001.3 AGL3
0 0 82 40 56 35 65 25 64 0
92 79 1 4 0 0 1 4 0 0
0 0 2 3 1 0 4 3 28 92
3 16 10 48 38 60 25 63 3 3
"""
class TestJASPAR16(_TestLoad, unittest.TestCase):
format = "jaspar16"
length = 2
first = "MA0001.3"
text = """
>MA0001.3 AGL3
A [ 0 0 82 40 56 35 65 25 64 0 ]
C [ 92 79 1 4 0 0 1 4 0 0 ]
G [ 0 0 2 3 1 0 4 3 28 92 ]
T [ 3 16 10 48 38 60 25 63 3 3 ]
>MA0017.3 NR2F1
A [ 7266 6333 8496 0 0 0 0 12059 5116 3229 3276 5681 ]
C [ 1692 387 0 30 0 0 12059 0 3055 2966 2470 2912 ]
G [ 1153 4869 3791 12059 12059 0 0 91 1618 4395 3886 36863 ]
T [ 1948 469 0 0 0 12059 0 0 2270 1469 2427 3466 ]
"""
class TestTRANSFAC(_TestLoad, unittest.TestCase):
format = "transfac"
length = 1
first = "M00005"
text = """
AC M00005
P0 A C G T
01 3 0 0 2 W
02 1 1 3 0 G
03 3 1 1 0 A
04 2 1 2 0 R
05 1 2 0 2 Y
06 0 5 0 0 C
07 5 0 0 0 A
08 0 0 5 0 G
09 0 5 0 0 C
10 0 0 1 4 T
11 0 1 4 0 G
12 0 2 1 2 Y
13 1 0 3 1 G
14 0 0 5 0 G
15 1 1 1 2 N
16 1 4 0 0 C
17 2 1 1 1 N
18 0 0 3 2 K
//
""".lstrip("\n")
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment