From f6950e9b9ad2c23d29c6a0f93e1da98a20ceb6d3 Mon Sep 17 00:00:00 2001 From: Martin Larralde <martin.larralde@embl.de> Date: Mon, 2 Sep 2024 13:08:55 +0200 Subject: [PATCH] Implement `Scanner.__init__` and fix type annotations in Python package --- lightmotif-py/lightmotif/lib.pyi | 11 +++++ lightmotif-py/lightmotif/lib.rs | 77 +++++++++++++++++++------------- 2 files changed, 57 insertions(+), 31 deletions(-) diff --git a/lightmotif-py/lightmotif/lib.pyi b/lightmotif-py/lightmotif/lib.pyi index 41cc93c..91cd2e3 100644 --- a/lightmotif-py/lightmotif/lib.pyi +++ b/lightmotif-py/lightmotif/lib.pyi @@ -84,12 +84,23 @@ class TransfacMotif(Motif): class JasparMotif(Motif): @property def counts(self) -> CountMatrix: ... + @property + def name(self) -> str: ... + @property + def description(self) -> Optional[str]: ... class UniprobeMotif(Motif): @property def counts(self) -> None: ... class Scanner(Iterator[Hit]): + def __init__( + self, + pssm: ScoringMatrix, + sequence: StripedSequence, + threshold: float = 0.0, + block_size: int = 256, + ) -> None: ... def __iter__(self) -> Scanner: ... def __next__(self) -> Hit: ... diff --git a/lightmotif-py/lightmotif/lib.rs b/lightmotif-py/lightmotif/lib.rs index 6eedc89..98b4d19 100644 --- a/lightmotif-py/lightmotif/lib.rs +++ b/lightmotif-py/lightmotif/lib.rs @@ -312,7 +312,10 @@ impl From<StripedSequenceData> for StripedSequence { let cols = data.columns(); let rows = data.rows(); let shape = [cols as Py_ssize_t, rows as Py_ssize_t]; - let strides = [1, data.stride() as Py_ssize_t]; + let strides = [ + std::mem::size_of::<u8>() as Py_ssize_t, + (std::mem::size_of::<u8>() * data.stride()) as Py_ssize_t, + ]; Self { data, shape, @@ -1107,6 +1110,44 @@ pub struct Scanner { #[pymethods] impl Scanner { + #[new] + #[pyo3(signature = (pssm, sequence, threshold = 0.0, block_size = 256))] + fn __init__<'py>( + pssm: Bound<'py, ScoringMatrix>, + sequence: Bound<'py, StripedSequence>, + threshold: f32, + block_size: usize, + ) -> PyResult<PyClassInitializer<Self>> { + match ( + &pssm.try_borrow()?.data, + &mut sequence.try_borrow_mut()?.data, + ) { + (ScoringMatrixData::Dna(p), StripedSequenceData::Dna(s)) => { + s.configure(&p); + // transmute (!!!!!) the scanner so that its lifetime is 'static + // (i.e. the reference to the PSSM and sequence never expire), + // which is possible because we are building a self-referential + // struct and self.scanner.next() will only be called with + // the GIL held... + let scanner = unsafe { + let mut scanner = lightmotif::scan::Scanner::<Dna, _, _, C>::new(p, s); + scanner.threshold(threshold); + scanner.block_size(block_size); + std::mem::transmute(scanner) + }; + Ok(PyClassInitializer::from(Scanner { + data: scanner, + pssm: pssm.unbind(), + sequence: sequence.unbind(), + })) + } + (ScoringMatrixData::Protein(_), StripedSequenceData::Protein(_)) => { + Err(PyValueError::new_err("protein scanner is not supported")) + } + (_, _) => Err(PyValueError::new_err("alphabet mismatch")), + } + } + fn __iter__(slf: PyRef<Self>) -> PyRef<Self> { slf } @@ -1216,36 +1257,10 @@ pub fn scan<'py>( block_size: usize, ) -> PyResult<Bound<'py, Scanner>> { let py = pssm.py(); - match ( - &pssm.try_borrow()?.data, - &mut sequence.try_borrow_mut()?.data, - ) { - (ScoringMatrixData::Dna(p), StripedSequenceData::Dna(s)) => { - s.configure(&p); - // transmute (!!!!!) the scanner so that its lifetime is 'static - // (i.e. the reference to the PSSM and sequence never expire), - // which is possible because we are building a self-referential - // struct - let scanner = unsafe { - let mut scanner = lightmotif::scan::Scanner::<Dna, _, _, C>::new(p, s); - scanner.threshold(threshold); - scanner.block_size(block_size); - std::mem::transmute(scanner) - }; - Bound::new( - py, - Scanner { - data: scanner, - pssm: pssm.unbind(), - sequence: sequence.unbind(), - }, - ) - } - (ScoringMatrixData::Protein(_), StripedSequenceData::Protein(_)) => { - Err(PyValueError::new_err("protein scanner is not supported")) - } - (_, _) => Err(PyValueError::new_err("alphabet mismatch")), - } + Bound::new( + py, + Scanner::__init__(pssm, sequence, threshold, block_size)?, + ) } // --- Module ------------------------------------------------------------------ -- GitLab