Skip to content
Snippets Groups Projects
Commit 6b551229 authored by Martin Larralde's avatar Martin Larralde
Browse files

Implement buffer protocol for `ScoringMatrix`

parent 5c36d18f
No related branches found
No related tags found
No related merge requests found
......@@ -34,6 +34,36 @@ use pyo3::types::PyDict;
use pyo3::types::PyList;
use pyo3::types::PyString;
macro_rules! impl_matrix_methods {
($datatype:ident) => {
impl $datatype {
#[allow(unused)]
fn rows(&self) -> usize {
match self {
Self::Dna(dna) => dna.matrix().rows(),
Self::Protein(protein) => protein.matrix().rows(),
}
}
#[allow(unused)]
fn columns(&self) -> usize {
match self {
Self::Dna(dna) => dna.matrix().columns(),
Self::Protein(protein) => protein.matrix().columns(),
}
}
#[allow(unused)]
fn stride(&self) -> usize {
match self {
Self::Dna(dna) => dna.matrix().stride(),
Self::Protein(protein) => protein.matrix().stride(),
}
}
}
};
}
// --- Compile-time constants --------------------------------------------------
#[cfg(any(target_arch = "arm", target_arch = "aarch64"))]
......@@ -250,28 +280,7 @@ enum StripedSequenceData {
Protein(lightmotif::seq::StripedSequence<Protein, C>),
}
impl StripedSequenceData {
fn rows(&self) -> usize {
match self {
Self::Dna(dna) => dna.matrix().rows(),
Self::Protein(protein) => protein.matrix().rows(),
}
}
fn columns(&self) -> usize {
match self {
Self::Dna(dna) => dna.matrix().columns(),
Self::Protein(protein) => protein.matrix().columns(),
}
}
fn stride(&self) -> usize {
match self {
Self::Dna(dna) => dna.matrix().stride(),
Self::Protein(protein) => protein.matrix().stride(),
}
}
}
impl_matrix_methods!(StripedSequenceData);
impl From<lightmotif::seq::StripedSequence<Dna, C>> for StripedSequenceData {
fn from(dna: lightmotif::seq::StripedSequence<Dna, C>) -> Self {
......@@ -296,11 +305,9 @@ pub struct StripedSequence {
impl From<StripedSequenceData> for StripedSequence {
fn from(data: StripedSequenceData) -> Self {
// extract the matrix shape and strides
let cols = data.columns();
let rows = data.rows();
let shape = [cols as Py_ssize_t, rows as Py_ssize_t];
// extract the matrix strides
let strides = [1, data.stride() as Py_ssize_t];
Self {
data,
......@@ -368,14 +375,7 @@ pub enum CountMatrixData {
Protein(lightmotif::CountMatrix<Protein>),
}
impl CountMatrixData {
fn rows(&self) -> usize {
match self {
Self::Dna(dna) => dna.matrix().rows(),
Self::Protein(protein) => protein.matrix().rows(),
}
}
}
impl_matrix_methods!(CountMatrixData);
impl From<lightmotif::CountMatrix<Dna>> for CountMatrixData {
fn from(data: lightmotif::CountMatrix<Dna>) -> Self {
......@@ -543,14 +543,7 @@ pub enum WeightMatrixData {
Protein(lightmotif::WeightMatrix<Protein>),
}
impl WeightMatrixData {
fn rows(&self) -> usize {
match self {
Self::Dna(dna) => dna.matrix().rows(),
Self::Protein(protein) => protein.matrix().rows(),
}
}
}
impl_matrix_methods!(WeightMatrixData);
impl From<lightmotif::WeightMatrix<Dna>> for WeightMatrixData {
fn from(data: lightmotif::WeightMatrix<Dna>) -> Self {
......@@ -655,7 +648,7 @@ impl WeightMatrix {
impl From<WeightMatrixData> for WeightMatrix {
fn from(data: WeightMatrixData) -> Self {
Self { data }
Self::new(data)
}
}
......@@ -668,14 +661,7 @@ enum ScoringMatrixData {
Protein(lightmotif::ScoringMatrix<Protein>),
}
impl ScoringMatrixData {
fn rows(&self) -> usize {
match self {
Self::Dna(dna) => dna.matrix().rows(),
Self::Protein(prot) => prot.matrix().rows(),
}
}
}
impl_matrix_methods!(ScoringMatrixData);
impl From<lightmotif::ScoringMatrix<Dna>> for ScoringMatrixData {
fn from(value: lightmotif::ScoringMatrix<Dna>) -> Self {
......@@ -694,6 +680,8 @@ impl From<lightmotif::ScoringMatrix<Protein>> for ScoringMatrixData {
#[derive(Clone, Debug)]
pub struct ScoringMatrix {
data: ScoringMatrixData,
shape: [Py_ssize_t; 2],
strides: [Py_ssize_t; 2],
}
impl ScoringMatrix {
......@@ -701,7 +689,20 @@ impl ScoringMatrix {
where
D: Into<ScoringMatrixData>,
{
Self { data: data.into() }
let data = data.into();
let cols = data.columns();
let rows = data.rows();
let stride = data.stride();
let shape = [cols as Py_ssize_t, rows as Py_ssize_t];
let strides = [
(stride * std::mem::size_of::<f32>()) as Py_ssize_t,
std::mem::size_of::<f32>() as Py_ssize_t,
];
Self {
data,
shape,
strides,
}
}
}
......@@ -786,6 +787,43 @@ impl ScoringMatrix {
Ok(row)
}
unsafe fn __getbuffer__(
mut slf: PyRefMut<'_, Self>,
view: *mut pyo3::ffi::Py_buffer,
flags: std::os::raw::c_int,
) -> PyResult<()> {
if view.is_null() {
return Err(PyBufferError::new_err("View is null"));
}
if (flags & pyo3::ffi::PyBUF_WRITABLE) == pyo3::ffi::PyBUF_WRITABLE {
return Err(PyBufferError::new_err("Object is not writable"));
}
(*view).obj = pyo3::ffi::_Py_NewRef(slf.as_ptr());
let data = match &slf.data {
ScoringMatrixData::Dna(dna) => dna.matrix()[0].as_ptr() as *const f32,
ScoringMatrixData::Protein(prot) => prot.matrix()[0].as_ptr() as *const f32,
};
(*view).buf = data as *mut std::os::raw::c_void;
(*view).len = -1;
(*view).readonly = 1;
(*view).itemsize = std::mem::size_of::<f32>() as isize;
let msg = std::ffi::CStr::from_bytes_with_nul(b"f\0").unwrap();
(*view).format = msg.as_ptr() as *mut _;
(*view).ndim = 2;
(*view).shape = slf.shape.as_mut_ptr();
(*view).strides = slf.strides.as_mut_ptr();
(*view).suboffsets = std::ptr::null_mut();
(*view).internal = std::ptr::null_mut();
Ok(())
}
/// Calculate the PSSM score for all positions of the given sequence.
///
/// Returns:
......@@ -859,7 +897,7 @@ impl ScoringMatrix {
impl From<ScoringMatrixData> for ScoringMatrix {
fn from(data: ScoringMatrixData) -> Self {
Self { data }
Self::new(data)
}
}
......@@ -905,7 +943,7 @@ impl StripedScores {
let data = slf.scores.matrix()[0].as_ptr();
(*view).buf = data as *mut std::os::raw::c_void;
(*view).len = (slf.scores.matrix().rows() * slf.scores.matrix().columns()) as isize;
(*view).len = -1;
(*view).readonly = 1;
(*view).itemsize = std::mem::size_of::<f32>() as isize;
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment