diff --git a/lightmotif-py/lightmotif/io.rs b/lightmotif-py/lightmotif/io.rs index 6d18d4807794f922c234fdfbc1636f419db973bc..af7b0d2532d265b0a2f677662db847b5ba3b39f3 100644 --- a/lightmotif-py/lightmotif/io.rs +++ b/lightmotif-py/lightmotif/io.rs @@ -9,6 +9,7 @@ use lightmotif_io::error::Error; use pyo3::exceptions::PyOSError; use pyo3::exceptions::PyValueError; use pyo3::prelude::*; +use pyo3::types::PyBytes; use pyo3::types::PyString; use super::pyfile::PyFileRead; @@ -160,8 +161,18 @@ impl Loader { #[pyfunction] #[pyo3(signature = (file, format="jaspar", *, protein=false))] pub fn load(file: Bound<PyAny>, format: &str, protein: bool) -> PyResult<Loader> { - let b: Box<dyn BufRead + Send> = if let Ok(s) = file.downcast::<PyString>() { - std::fs::File::open(s.to_str()?) + let py = file.py(); + let pathlike = py + .import_bound(pyo3::intern!(py, "os"))? + .call_method1("fsencode", (&file,)); + let b: Box<dyn BufRead + Send> = if let Ok(path) = pathlike { + // NOTE(@althonos): In theory this is safe because `os.fsencode` encodes + // the PathLike object into the OS preferred encoding, + // which is what OsStr wants. In practice, there may be + // some weird bugs if that encoding is incorrect, idk... + let encoded = path.downcast::<PyBytes>()?; + let s = unsafe { std::ffi::OsStr::from_encoded_bytes_unchecked(encoded.as_bytes()) }; + std::fs::File::open(s) .map(std::io::BufReader::new) .map(Box::new)? 
} else { diff --git a/lightmotif-py/lightmotif/tests/test_load.py b/lightmotif-py/lightmotif/tests/test_load.py index cb6f4b5189f78a325929616c49b140ff334f85a1..fae35ad010aa79a9a91dec25b0be131560e06bec 100644 --- a/lightmotif-py/lightmotif/tests/test_load.py +++ b/lightmotif-py/lightmotif/tests/test_load.py @@ -1,7 +1,9 @@ import unittest import io +import os import textwrap import tempfile +import pathlib import lightmotif @@ -24,6 +26,25 @@ class _TestLoad(): self.assertEqual(len(motifs), self.length) self.assertEqual(motifs[0].name, self.first) + def test_load_filename_bytes(self): + text = textwrap.dedent(self.text).encode() + with tempfile.NamedTemporaryFile("r+b") as f: + f.write(text) + f.flush() + motifs = list(lightmotif.load(os.fsencode(f.name), self.format)) + self.assertEqual(len(motifs), self.length) + self.assertEqual(motifs[0].name, self.first) + + def test_load_path(self): + text = textwrap.dedent(self.text).encode() + with tempfile.NamedTemporaryFile("r+b") as f: + f.write(text) + f.flush() + path = pathlib.Path(f.name) + motifs = list(lightmotif.load(path, self.format)) + self.assertEqual(len(motifs), self.length) + self.assertEqual(motifs[0].name, self.first) + class TestJASPAR(_TestLoad, unittest.TestCase): format = "jaspar"