diff --git a/lightmotif-py/lightmotif/lib.rs b/lightmotif-py/lightmotif/lib.rs
index 3fd27df19e52c0c7125120b4fc66de9c5e4d14c1..2ac43b1175f9d635d85db74dd3590f796c19a137 100644
--- a/lightmotif-py/lightmotif/lib.rs
+++ b/lightmotif-py/lightmotif/lib.rs
@@ -34,6 +34,42 @@ use pyo3::types::PyDict;
 use pyo3::types::PyList;
 use pyo3::types::PyString;
 
+/// Generate `rows`, `columns` and `stride` accessors for a data enum.
+///
+/// The target type must have `Dna` and `Protein` variants whose payloads
+/// provide a `matrix()` method; every accessor forwards to that matrix.
+/// The accessors are `#[allow(unused)]` because the enums that used to
+/// define only `rows` may not call the other two.
+macro_rules! impl_matrix_methods {
+    ($datatype:ident) => {
+        impl $datatype {
+            #[allow(unused)]
+            fn rows(&self) -> usize {
+                match self {
+                    Self::Dna(dna) => dna.matrix().rows(),
+                    Self::Protein(protein) => protein.matrix().rows(),
+                }
+            }
+
+            #[allow(unused)]
+            fn columns(&self) -> usize {
+                match self {
+                    Self::Dna(dna) => dna.matrix().columns(),
+                    Self::Protein(protein) => protein.matrix().columns(),
+                }
+            }
+
+            #[allow(unused)]
+            fn stride(&self) -> usize {
+                match self {
+                    Self::Dna(dna) => dna.matrix().stride(),
+                    Self::Protein(protein) => protein.matrix().stride(),
+                }
+            }
+        }
+    };
+}
+
 // --- Compile-time constants --------------------------------------------------
 
 #[cfg(any(target_arch = "arm", target_arch = "aarch64"))]
@@ -250,28 +286,7 @@ enum StripedSequenceData {
     Protein(lightmotif::seq::StripedSequence<Protein, C>),
 }
 
-impl StripedSequenceData {
-    fn rows(&self) -> usize {
-        match self {
-            Self::Dna(dna) => dna.matrix().rows(),
-            Self::Protein(protein) => protein.matrix().rows(),
-        }
-    }
-
-    fn columns(&self) -> usize {
-        match self {
-            Self::Dna(dna) => dna.matrix().columns(),
-            Self::Protein(protein) => protein.matrix().columns(),
-        }
-    }
-
-    fn stride(&self) -> usize {
-        match self {
-            Self::Dna(dna) => dna.matrix().stride(),
-            Self::Protein(protein) => protein.matrix().stride(),
-        }
-    }
-}
+impl_matrix_methods!(StripedSequenceData);
 
 impl From<lightmotif::seq::StripedSequence<Dna, C>> for StripedSequenceData {
     fn from(dna: lightmotif::seq::StripedSequence<Dna, C>) -> Self {
@@ -296,11 +311,9 @@ pub struct StripedSequence {
 
 impl From<StripedSequenceData> for StripedSequence {
     fn from(data: StripedSequenceData) -> Self {
-        // extract the matrix shape and strides
         let cols = data.columns();
         let rows = data.rows();
         let shape = [cols as Py_ssize_t, rows as Py_ssize_t];
-        // extract the matrix strides
         let strides = [1, data.stride() as Py_ssize_t];
         Self {
             data,
@@ -368,14 +381,7 @@ pub enum CountMatrixData {
     Protein(lightmotif::CountMatrix<Protein>),
 }
 
-impl CountMatrixData {
-    fn rows(&self) -> usize {
-        match self {
-            Self::Dna(dna) => dna.matrix().rows(),
-            Self::Protein(protein) => protein.matrix().rows(),
-        }
-    }
-}
+impl_matrix_methods!(CountMatrixData);
 
 impl From<lightmotif::CountMatrix<Dna>> for CountMatrixData {
     fn from(data: lightmotif::CountMatrix<Dna>) -> Self {
@@ -543,14 +549,7 @@ pub enum WeightMatrixData {
     Protein(lightmotif::WeightMatrix<Protein>),
 }
 
-impl WeightMatrixData {
-    fn rows(&self) -> usize {
-        match self {
-            Self::Dna(dna) => dna.matrix().rows(),
-            Self::Protein(protein) => protein.matrix().rows(),
-        }
-    }
-}
+impl_matrix_methods!(WeightMatrixData);
 
 impl From<lightmotif::WeightMatrix<Dna>> for WeightMatrixData {
     fn from(data: lightmotif::WeightMatrix<Dna>) -> Self {
@@ -655,7 +654,7 @@ impl WeightMatrix {
 
 impl From<WeightMatrixData> for WeightMatrix {
     fn from(data: WeightMatrixData) -> Self {
-        Self { data }
+        Self::new(data)
     }
 }
 
@@ -668,14 +667,7 @@ enum ScoringMatrixData {
     Protein(lightmotif::ScoringMatrix<Protein>),
 }
 
-impl ScoringMatrixData {
-    fn rows(&self) -> usize {
-        match self {
-            Self::Dna(dna) => dna.matrix().rows(),
-            Self::Protein(prot) => prot.matrix().rows(),
-        }
-    }
-}
+impl_matrix_methods!(ScoringMatrixData);
 
 impl From<lightmotif::ScoringMatrix<Dna>> for ScoringMatrixData {
     fn from(value: lightmotif::ScoringMatrix<Dna>) -> Self {
@@ -694,6 +686,8 @@ impl From<lightmotif::ScoringMatrix<Protein>> for ScoringMatrixData {
 #[derive(Clone, Debug)]
 pub struct ScoringMatrix {
     data: ScoringMatrixData,
+    shape: [Py_ssize_t; 2],
+    strides: [Py_ssize_t; 2],
 }
 
 impl ScoringMatrix {
@@ -701,7 +695,24 @@ impl ScoringMatrix {
     where
         D: Into<ScoringMatrixData>,
     {
-        Self { data: data.into() }
+        let data = data.into();
+        let cols = data.columns();
+        let rows = data.rows();
+        let stride = data.stride();
+        // List the shape as (rows, columns) so that it pairs with the
+        // strides below: the first stride steps over a whole row of
+        // `stride` floats (assuming `stride()` counts elements per
+        // padded row), the second over a single float.
+        let shape = [rows as Py_ssize_t, cols as Py_ssize_t];
+        let strides = [
+            (stride * std::mem::size_of::<f32>()) as Py_ssize_t,
+            std::mem::size_of::<f32>() as Py_ssize_t,
+        ];
+        Self {
+            data,
+            shape,
+            strides,
+        }
     }
 }
 
@@ -786,6 +797,48 @@ impl ScoringMatrix {
         Ok(row)
     }
 
+    /// Expose the matrix as a read-only, two-dimensional buffer of `f32`.
+    ///
+    /// Fails with `BufferError` when `view` is null or when the consumer
+    /// requests a writable buffer.
+    unsafe fn __getbuffer__(
+        mut slf: PyRefMut<'_, Self>,
+        view: *mut pyo3::ffi::Py_buffer,
+        flags: std::os::raw::c_int,
+    ) -> PyResult<()> {
+        if view.is_null() {
+            return Err(PyBufferError::new_err("View is null"));
+        }
+        if (flags & pyo3::ffi::PyBUF_WRITABLE) == pyo3::ffi::PyBUF_WRITABLE {
+            return Err(PyBufferError::new_err("Object is not writable"));
+        }
+
+        (*view).obj = pyo3::ffi::_Py_NewRef(slf.as_ptr());
+
+        let data = match &slf.data {
+            ScoringMatrixData::Dna(dna) => dna.matrix()[0].as_ptr() as *const f32,
+            ScoringMatrixData::Protein(prot) => prot.matrix()[0].as_ptr() as *const f32,
+        };
+
+        (*view).buf = data as *mut std::os::raw::c_void;
+        // `len` is the total byte size: product(shape) * itemsize.
+        (*view).len = slf.shape[0] * slf.shape[1] * std::mem::size_of::<f32>() as isize;
+        (*view).readonly = 1;
+        (*view).itemsize = std::mem::size_of::<f32>() as isize;
+
+        let msg = std::ffi::CStr::from_bytes_with_nul(b"f\0").unwrap();
+        (*view).format = msg.as_ptr() as *mut _;
+
+        (*view).ndim = 2;
+        (*view).shape = slf.shape.as_mut_ptr();
+        (*view).strides = slf.strides.as_mut_ptr();
+
+        (*view).suboffsets = std::ptr::null_mut();
+        (*view).internal = std::ptr::null_mut();
+
+        Ok(())
+    }
+
     /// Calculate the PSSM score for all positions of the given sequence.
     ///
     /// Returns:
@@ -859,7 +912,7 @@ impl ScoringMatrix {
 
 impl From<ScoringMatrixData> for ScoringMatrix {
     fn from(data: ScoringMatrixData) -> Self {
-        Self { data }
+        Self::new(data)
     }
 }
 
@@ -905,7 +958,10 @@ impl StripedScores {
 
         let data = slf.scores.matrix()[0].as_ptr();
         (*view).buf = data as *mut std::os::raw::c_void;
-        (*view).len = (slf.scores.matrix().rows() * slf.scores.matrix().columns()) as isize;
+        // `len` is the total byte size: product(shape) * itemsize.
+        (*view).len = (slf.scores.matrix().rows()
+            * slf.scores.matrix().columns()
+            * std::mem::size_of::<f32>()) as isize;
         (*view).readonly = 1;
         (*view).itemsize = std::mem::size_of::<f32>() as isize;
 