diff --git a/lightmotif-py/lightmotif/lib.rs b/lightmotif-py/lightmotif/lib.rs index 8d6c72310321840e58314dac2bc7e4dd4f9e2f3d..959abda3525d54204df216761fa67e74b7ad665f 100644 --- a/lightmotif-py/lightmotif/lib.rs +++ b/lightmotif-py/lightmotif/lib.rs @@ -6,9 +6,11 @@ extern crate lightmotif; extern crate lightmotif_tfmpvalue; extern crate pyo3; +use std::borrow::Borrow; use std::fmt::Display; use std::fmt::Formatter; use std::pin::Pin; +use std::sync::Arc; use lightmotif::abc::Alphabet; use lightmotif::abc::Dna; @@ -1089,18 +1091,20 @@ impl Motif { // --- Scanner ----------------------------------------------------------------- -#[derive(Debug)] -enum ScannerData<'py> { - Dna(lightmotif::scan::Scanner<'py, Dna, C>), - // Protein(lightmotif::scan::Scanner::<'py, Protein, C>), -} - #[pyclass(module = "lightmotif.lib")] #[derive(Debug)] pub struct Scanner { - pssm: Pin<Box<lightmotif::pwm::ScoringMatrix<Dna>>>, - sequence: Pin<Box<lightmotif::seq::StripedSequence<Dna, C>>>, - data: lightmotif::scan::Scanner<'static, Dna, C>, + #[allow(unused)] + pssm: Py<ScoringMatrix>, + #[allow(unused)] + sequence: Py<StripedSequence>, + data: lightmotif::scan::Scanner< + 'static, + Dna, + &'static lightmotif::pwm::ScoringMatrix<Dna>, + &'static lightmotif::seq::StripedSequence<Dna, C>, + C, + >, } #[pymethods] @@ -1208,35 +1212,37 @@ pub fn stripe(sequence: Bound<PyAny>, protein: bool) -> PyResult<StripedSequence /// Scan a sequence using a fast scanner implementation to identify hits. #[pyfunction] #[pyo3(signature = (pssm, sequence, threshold = 0.0, block_size = 256))] -pub fn scan( - pssm: Bound<ScoringMatrix>, - sequence: Bound<StripedSequence>, +pub fn scan<'py>( + pssm: Bound<'py, ScoringMatrix>, + sequence: Bound<'py, StripedSequence>, threshold: f32, block_size: usize, -) -> PyResult<Scanner> { - match (&pssm.try_borrow()?.data, &sequence.try_borrow()?.data) { +) -> PyResult<Bound<'py, Scanner>> { + let py = pssm.py(); + match ( + &pssm.try_borrow()?.data, + &mut sequence.try_borrow_mut()?.data, + ) { (ScoringMatrixData::Dna(p), StripedSequenceData::Dna(s)) => { - // Move data to the heap so that it doesn't get unallocated - // while the scanner is in use (the scanner needs to have - // access to the data through a reference) - let p = Box::new(p.clone()); - let mut s = Box::new(s.clone()); s.configure(&p); // transmute (!!!!!) the scanner so that its lifetime is 'static // (i.e. the reference to the PSSM and sequence never expire), // which is possible because we are building a self-referential // struct let scanner = unsafe { - let mut s = lightmotif::scan::Scanner::<Dna, C>::new(&p, &s); - s.threshold(threshold); - s.block_size(block_size); - std::mem::transmute(s) + let mut scanner = lightmotif::scan::Scanner::<Dna, _, _, C>::new(p, s); + scanner.threshold(threshold); + scanner.block_size(block_size); + std::mem::transmute(scanner) }; - Ok(Scanner { - data: scanner, - pssm: Pin::new(p), - sequence: Pin::new(s), - }) + Bound::new( + py, + Scanner { + data: scanner, + pssm: pssm.unbind(), + sequence: sequence.unbind(), + }, + ) } (ScoringMatrixData::Protein(_), StripedSequenceData::Protein(_)) => { Err(PyValueError::new_err("protein scanner is not supported")) diff --git a/lightmotif/src/scan.rs b/lightmotif/src/scan.rs index 86054b3f22b51f54bb27064d81e3779fcb56fbee..d9adc162f68d916a034ac1a1578c4aff68f60732 100644 --- a/lightmotif/src/scan.rs +++ b/lightmotif/src/scan.rs @@ -93,10 +93,16 @@ impl Ord for Hit { /// A scanner for iterating over scoring matrix hits in a sequence. #[derive(Debug)] -pub struct Scanner<'a, A: Alphabet, C: StrictlyPositive = DefaultColumns> { - pssm: &'a ScoringMatrix<A>, +pub struct Scanner< + 'a, + A: Alphabet, + M: AsRef<ScoringMatrix<A>>, + S: AsRef<StripedSequence<A, C>>, + C: StrictlyPositive = DefaultColumns, +> { + pssm: M, dm: DiscreteMatrix<A>, - seq: &'a StripedSequence<A, C>, + seq: S, scores: CowMut<'a, StripedScores<f32, C>>, dscores: StripedScores<u8, C>, threshold: f32, @@ -106,13 +112,17 @@ pub struct Scanner<'a, A: Alphabet, C: StrictlyPositive = DefaultColumns> { pipeline: Pipeline<A, Dispatch>, } -impl<'a, A: Alphabet, C: StrictlyPositive> Scanner<'a, A, C> { +impl<'a, A, M, S, C> Scanner<'a, A, M, S, C> +where + A: Alphabet, + C: StrictlyPositive, + M: AsRef<ScoringMatrix<A>>, + S: AsRef<StripedSequence<A, C>>, +{ /// Create a new scanner for the given matrix and sequence. - pub fn new(pssm: &'a ScoringMatrix<A>, seq: &'a StripedSequence<A, C>) -> Self { + pub fn new(pssm: M, seq: S) -> Self { Self { - pssm, - seq, - dm: pssm.to_discrete(), + dm: pssm.as_ref().to_discrete(), scores: CowMut::Owned(StripedScores::empty()), dscores: StripedScores::empty(), threshold: 0.0, @@ -120,6 +130,8 @@ impl<'a, A: Alphabet, C: StrictlyPositive> Scanner<'a, A, C> { row: 0, hits: Vec::new(), pipeline: Pipeline::dispatch(), + pssm, + seq, } } @@ -142,19 +154,22 @@ impl<'a, A: Alphabet, C: StrictlyPositive> Scanner<'a, A, C> { } } -impl<'a, A, C> Iterator for Scanner<'a, A, C> +impl<'a, A, M, S, C> Iterator for Scanner<'a, A, M, S, C> where A: Alphabet, C: StrictlyPositive, + M: AsRef<ScoringMatrix<A>>, + S: AsRef<StripedSequence<A, C>>, Pipeline<A, Dispatch>: Score<u8, A, C> + Threshold<u8, C> + Maximum<u8, C>, { type Item = Hit; fn next(&mut self) -> Option<Self::Item> { + let seq = self.seq.as_ref(); let t = self.dm.scale(self.threshold); - while self.hits.is_empty() && self.row < self.seq.matrix().rows() { + while self.hits.is_empty() && self.row < seq.matrix().rows() { // compute the row slice to score in the striped sequence matrix - let end = (self.row + self.block_size) - .min(self.seq.matrix().rows().saturating_sub(self.seq.wrap())); + let end = + (self.row + self.block_size).min(seq.matrix().rows().saturating_sub(seq.wrap())); // score the row slice self.pipeline .score_rows_into(&self.dm, &self.seq, self.row..end, &mut self.dscores); @@ -163,9 +178,8 @@ where // scan through the positions above discrete threshold and recompute // scores in floating-point to see if they pass the real threshold. for c in self.pipeline.threshold(&self.dscores, t) { - let index = - c.col * (self.seq.matrix().rows() - self.seq.wrap()) + self.row + c.row; - let score = self.pssm.score_position(&self.seq, index); + let index = c.col * (seq.matrix().rows() - seq.wrap()) + self.row + c.row; + let score = self.pssm.as_ref().score_position(seq, index); if score >= self.threshold { self.hits.push(Hit::new(index, score)); } @@ -178,6 +192,8 @@ where } fn max(mut self) -> Option<Self::Item> { + let seq = self.seq.as_ref(); + // Compute the score of the best hit not yet returned, and translate // the `f32` score threshold into a discrete, under-estimate `u8` // threshold. @@ -190,15 +206,15 @@ where None => self.dm.scale(self.threshold), }; - // Cache th number of sequence rows in the striped sequence matrix. - let sequence_rows = self.seq.matrix().rows() - self.seq.wrap(); + // Cache the number of sequence rows in the striped sequence matrix. + let sequence_rows = seq.matrix().rows() - seq.wrap(); // Process all rows of the sequence and record the local - while self.row < self.seq.matrix().rows() { + while self.row < seq.matrix().rows() { // Score the rows of the current block. let end = (self.row + self.block_size).min(sequence_rows); self.pipeline - .score_rows_into(&self.dm, &self.seq, self.row..end, &mut self.dscores); + .score_rows_into(&self.dm, seq, self.row..end, &mut self.dscores); // Check if the highest score in the block is high enough to be // a new global maximum if self.pipeline.max(&self.dscores).unwrap() >= best_discrete { @@ -208,7 +224,7 @@ where let dscore = self.dscores.matrix()[c]; if dscore >= best_discrete { let index = c.col * sequence_rows + self.row + c.row; - let score = self.pssm.score_position(&self.seq, index); + let score = self.pssm.as_ref().score_position(&self.seq, index); if let Some(hit) = &best { if (score > hit.score) | (score == hit.score && index > hit.position) { best = Some(Hit::new(index, score)); diff --git a/lightmotif/src/seq.rs b/lightmotif/src/seq.rs index 85e4646194c4941c19d653e1755a86d3c5f8862c..0a0347662547ee2f0a6ac4fc28151980c8b16ea4 100644 --- a/lightmotif/src/seq.rs +++ b/lightmotif/src/seq.rs @@ -10,7 +10,6 @@ use std::str::FromStr; use rand::Rng; use super::abc::Alphabet; -use super::abc::Background; use super::abc::Symbol; use super::dense::DenseMatrix; use super::err::InvalidData;