diff --git a/lightmotif-io/src/jaspar/mod.rs b/lightmotif-io/src/jaspar/mod.rs index 4c69a67963f9777d872a5aab2c406ea5824a1eef..98168f75151ba33794e825163df5631151593a16 100644 --- a/lightmotif-io/src/jaspar/mod.rs +++ b/lightmotif-io/src/jaspar/mod.rs @@ -60,6 +60,12 @@ impl AsRef<CountMatrix<Dna>> for Record { } } +impl From<Record> for CountMatrix<Dna> { + fn from(value: Record) -> Self { + value.matrix + } +} + // --- /// An iterative reader for the JASPAR format. diff --git a/lightmotif-io/src/jaspar16/mod.rs b/lightmotif-io/src/jaspar16/mod.rs index e1ec1d54d3acb118d54a2a65b4e35cd4922c241d..d302ab705529ae3b70553a4d168695d3bd8275d3 100644 --- a/lightmotif-io/src/jaspar16/mod.rs +++ b/lightmotif-io/src/jaspar16/mod.rs @@ -49,6 +49,11 @@ impl<A: Alphabet> Record<A> { pub fn matrix(&self) -> &CountMatrix<A> { &self.matrix } + + /// Take the count matrix of the record. + pub fn into_matrix(self) -> CountMatrix<A> { + self.matrix + } } impl<A: Alphabet> AsRef<CountMatrix<A>> for Record<A> { diff --git a/lightmotif-io/src/uniprobe/mod.rs b/lightmotif-io/src/uniprobe/mod.rs index 82546262df6d1d28fa502fff10bf488fa860f7a3..7a0797c8180b132dbabc65f8c599b679b952a9e1 100644 --- a/lightmotif-io/src/uniprobe/mod.rs +++ b/lightmotif-io/src/uniprobe/mod.rs @@ -39,6 +39,11 @@ impl<A: Alphabet> Record<A> { pub fn matrix(&self) -> &FrequencyMatrix<A> { &self.matrix } + + /// Take the frequency matrix of the record. + pub fn into_matrix(self) -> FrequencyMatrix<A> { + self.matrix + } } impl<A: Alphabet> AsRef<FrequencyMatrix<A>> for Record<A> { diff --git a/lightmotif-py/Cargo.toml b/lightmotif-py/Cargo.toml index c18617bc1e41827366d66e17581d5ae580159961..59d9bf3c21f0227666887deb8c6335c0a710aa1b 100644 --- a/lightmotif-py/Cargo.toml +++ b/lightmotif-py/Cargo.toml @@ -19,6 +19,9 @@ doctest = false [dependencies.lightmotif] path = "../lightmotif" version = "0.8.0" +[dependencies.lightmotif-io] +path = "../lightmotif-io" +version = "0.8.0" [dependencies.lightmotif-tfmpvalue] optional = true path = "../lightmotif-tfmpvalue" diff --git a/lightmotif-py/lightmotif/io.rs b/lightmotif-py/lightmotif/io.rs new file mode 100644 index 0000000000000000000000000000000000000000..897616992fd7147a41478748b6abacbe5850f88f --- /dev/null +++ b/lightmotif-py/lightmotif/io.rs @@ -0,0 +1,108 @@ +use lightmotif::abc::Alphabet; +use lightmotif::abc::Dna; +use lightmotif::abc::Protein; +use lightmotif_io::error::Error; + +use pyo3::exceptions::PyValueError; +use pyo3::prelude::*; + +use super::CountMatrixData; +use super::Motif; +use super::ScoringMatrixData; +use super::WeightMatrixData; + +fn convert_jaspar(record: Result<lightmotif_io::jaspar::Record, Error>) -> PyResult<Motif> { + let record = record.unwrap(); + let counts = record.into(); + Python::with_gil(|py| Motif::from_counts(py, counts)) +} + +fn convert_jaspar16<A>(record: Result<lightmotif_io::jaspar16::Record<A>, Error>) -> PyResult<Motif> +where + A: Alphabet, + CountMatrixData: From<lightmotif::pwm::CountMatrix<A>>, + WeightMatrixData: From<lightmotif::pwm::WeightMatrix<A>>, + ScoringMatrixData: From<lightmotif::pwm::ScoringMatrix<A>>, +{ + let record = record.unwrap(); + let counts = record.into_matrix(); + Python::with_gil(|py| Motif::from_counts(py, counts)) +} + +fn convert_transfac<A>(record: Result<lightmotif_io::transfac::Record<A>, Error>) -> PyResult<Motif> +where + A: Alphabet, + CountMatrixData: From<lightmotif::pwm::CountMatrix<A>>, + WeightMatrixData: From<lightmotif::pwm::WeightMatrix<A>>, + ScoringMatrixData: From<lightmotif::pwm::ScoringMatrix<A>>, +{ + let record = record.unwrap(); + let counts = record + .to_counts() + .ok_or_else(|| PyValueError::new_err("invalid count matrix"))?; + Python::with_gil(|py| Motif::from_counts(py, counts)) +} + +fn convert_uniprobe<A>(record: Result<lightmotif_io::uniprobe::Record<A>, Error>) -> PyResult<Motif> +where + A: Alphabet, + WeightMatrixData: From<lightmotif::pwm::WeightMatrix<A>>, + ScoringMatrixData: From<lightmotif::pwm::ScoringMatrix<A>>, +{ + let record = record.unwrap(); + let freqs = record.into_matrix(); + let weights = freqs.to_weight(None); + Python::with_gil(|py| Motif::from_weights(py, weights)) +} + +#[pyclass(module = "lightmotif.lib")] +pub struct Loader { + reader: Box<dyn Iterator<Item = PyResult<Motif>> + Send>, +} + +#[pymethods] +impl Loader { + fn __iter__(slf: PyRef<Self>) -> PyRef<Self> { + slf + } + + fn __next__(mut slf: PyRefMut<Self>) -> Option<PyResult<Motif>> { + slf.reader.next() + } +} + +#[pyfunction] +#[pyo3(signature = (path, format="jaspar", protein=false))] +pub fn load(path: &str, format: &str, protein: bool) -> PyResult<Loader> { + let file = std::fs::File::open(path) + .map(std::io::BufReader::new) + .unwrap(); + let reader: Box<dyn Iterator<Item = PyResult<Motif>> + Send> = match format { + "jaspar" if protein => { + return Err(PyValueError::new_err( + "cannot read protein motifs from JASPAR format", + )) + } + "jaspar16" if protein => Box::new( + lightmotif_io::jaspar16::read::<_, Protein>(file).map(convert_jaspar16::<Protein>), + ), + "transfac" if protein => Box::new( + lightmotif_io::transfac::read::<_, Protein>(file).map(convert_transfac::<Protein>), + ), + "uniprobe" if protein => Box::new( + lightmotif_io::uniprobe::read::<_, Protein>(file).map(convert_uniprobe::<Protein>), + ), + "jaspar" => Box::new(lightmotif_io::jaspar::read(file).map(convert_jaspar)), + "jaspar16" => { + Box::new(lightmotif_io::jaspar16::read::<_, Dna>(file).map(convert_jaspar16::<Dna>)) + } + "transfac" => { + Box::new(lightmotif_io::transfac::read::<_, Dna>(file).map(convert_transfac::<Dna>)) + } + "uniprobe" => { + Box::new(lightmotif_io::uniprobe::read::<_, Dna>(file).map(convert_uniprobe::<Dna>)) + } + _ => return Err(PyValueError::new_err(format!("invalid format: {}", format))), + }; + Ok(Loader { reader }) +} diff --git a/lightmotif-py/lightmotif/lib.rs b/lightmotif-py/lightmotif/lib.rs index 2ac43b1175f9d635d85db74dd3590f796c19a137..79bc65f24e985e51b06edab688224338d8ef4436 100644 --- a/lightmotif-py/lightmotif/lib.rs +++ b/lightmotif-py/lightmotif/lib.rs @@ -34,6 +34,10 @@ use pyo3::types::PyDict; use pyo3::types::PyList; use pyo3::types::PyString; +mod io; + +// --- Macros ------------------------------------------------------------------ + macro_rules! impl_matrix_methods { ($datatype:ident) => { impl $datatype { @@ -1040,13 +1044,45 @@ impl From<lightmotif::scores::StripedScores<f32, C>> for StripedScores { #[derive(Debug)] pub struct Motif { #[pyo3(get)] - counts: Py<CountMatrix>, + counts: Option<Py<CountMatrix>>, #[pyo3(get)] pwm: Py<WeightMatrix>, #[pyo3(get)] pssm: Py<ScoringMatrix>, } +impl Motif { + fn from_counts<A>(py: Python, counts: lightmotif::pwm::CountMatrix<A>) -> PyResult<Self> + where + A: Alphabet, + CountMatrixData: From<lightmotif::pwm::CountMatrix<A>>, + WeightMatrixData: From<lightmotif::pwm::WeightMatrix<A>>, + ScoringMatrixData: From<lightmotif::pwm::ScoringMatrix<A>>, + { + let weights = counts.to_freq(0.0).to_weight(None); + let scoring = weights.to_scoring(); + Ok(Motif { + counts: Some(Py::new(py, CountMatrix::new(counts))?), + pwm: Py::new(py, WeightMatrix::new(weights))?, + pssm: Py::new(py, ScoringMatrix::new(scoring))?, + }) + } + + fn from_weights<A>(py: Python, weights: lightmotif::pwm::WeightMatrix<A>) -> PyResult<Self> + where + A: Alphabet, + WeightMatrixData: From<lightmotif::pwm::WeightMatrix<A>>, + ScoringMatrixData: From<lightmotif::pwm::ScoringMatrix<A>>, + { + let scoring = weights.to_scoring(); + Ok(Motif { + counts: None, + pwm: Py::new(py, WeightMatrix::new(weights))?, + pssm: Py::new(py, ScoringMatrix::new(scoring))?, + }) + } +} + // --- Scanner ----------------------------------------------------------------- #[derive(Debug)] @@ -1139,7 +1175,7 @@ pub fn create(sequences: Bound<PyAny>, protein: bool) -> PyResult<Motif> { let scoring = weights.to_scoring(); Ok(Motif { - counts: Py::new(py, CountMatrix::new(data))?, + counts: Some(Py::new(py, CountMatrix::new(data))?), pwm: Py::new(py, WeightMatrix::new(weights))?, pssm: Py::new(py, ScoringMatrix::new(scoring))?, }) @@ -1251,9 +1287,12 @@ pub fn init<'py>(_py: Python<'py>, m: &Bound<PyModule>) -> PyResult<()> { m.add_class::<Scanner>()?; m.add_class::<Hit>()?; + m.add_class::<io::Loader>()?; + m.add_function(wrap_pyfunction!(create, m)?)?; m.add_function(wrap_pyfunction!(scan, m)?)?; m.add_function(wrap_pyfunction!(stripe, m)?)?; + m.add_function(wrap_pyfunction!(io::load, m)?)?; Ok(()) }