From cbc2ca63d2ec0551ef536de53f3bcd3b968229aa Mon Sep 17 00:00:00 2001
From: Martin Larralde <martin.larralde@embl.de>
Date: Sat, 31 Aug 2024 12:52:52 +0200
Subject: [PATCH] Fix `Scanner` in `lightmotif-py` to avoid copying data

---
 lightmotif-py/lightmotif/lib.rs | 62 ++++++++++++++++++---------------
 lightmotif/src/scan.rs          | 56 ++++++++++++++++++-----------
 lightmotif/src/seq.rs           |  1 -
 3 files changed, 70 insertions(+), 49 deletions(-)

diff --git a/lightmotif-py/lightmotif/lib.rs b/lightmotif-py/lightmotif/lib.rs
index 8d6c723..959abda 100644
--- a/lightmotif-py/lightmotif/lib.rs
+++ b/lightmotif-py/lightmotif/lib.rs
@@ -6,9 +6,11 @@ extern crate lightmotif;
 extern crate lightmotif_tfmpvalue;
 extern crate pyo3;
 
+use std::borrow::Borrow;
 use std::fmt::Display;
 use std::fmt::Formatter;
 use std::pin::Pin;
+use std::sync::Arc;
 
 use lightmotif::abc::Alphabet;
 use lightmotif::abc::Dna;
@@ -1089,18 +1091,20 @@ impl Motif {
 
 // --- Scanner -----------------------------------------------------------------
 
-#[derive(Debug)]
-enum ScannerData<'py> {
-    Dna(lightmotif::scan::Scanner<'py, Dna, C>),
-    // Protein(lightmotif::scan::Scanner::<'py, Protein, C>),
-}
-
 #[pyclass(module = "lightmotif.lib")]
 #[derive(Debug)]
 pub struct Scanner {
-    pssm: Pin<Box<lightmotif::pwm::ScoringMatrix<Dna>>>,
-    sequence: Pin<Box<lightmotif::seq::StripedSequence<Dna, C>>>,
-    data: lightmotif::scan::Scanner<'static, Dna, C>,
+    #[allow(unused)]
+    pssm: Py<ScoringMatrix>,
+    #[allow(unused)]
+    sequence: Py<StripedSequence>,
+    data: lightmotif::scan::Scanner<
+        'static,
+        Dna,
+        &'static lightmotif::pwm::ScoringMatrix<Dna>,
+        &'static lightmotif::seq::StripedSequence<Dna, C>,
+        C,
+    >,
 }
 
 #[pymethods]
@@ -1208,35 +1212,37 @@ pub fn stripe(sequence: Bound<PyAny>, protein: bool) -> PyResult<StripedSequence
 /// Scan a sequence using a fast scanner implementation to identify hits.
 #[pyfunction]
 #[pyo3(signature = (pssm, sequence, threshold = 0.0, block_size = 256))]
-pub fn scan(
-    pssm: Bound<ScoringMatrix>,
-    sequence: Bound<StripedSequence>,
+pub fn scan<'py>(
+    pssm: Bound<'py, ScoringMatrix>,
+    sequence: Bound<'py, StripedSequence>,
     threshold: f32,
     block_size: usize,
-) -> PyResult<Scanner> {
-    match (&pssm.try_borrow()?.data, &sequence.try_borrow()?.data) {
+) -> PyResult<Bound<'py, Scanner>> {
+    let py = pssm.py();
+    match (
+        &pssm.try_borrow()?.data,
+        &mut sequence.try_borrow_mut()?.data,
+    ) {
         (ScoringMatrixData::Dna(p), StripedSequenceData::Dna(s)) => {
-            // Move data to the heap so that it doesn't get unallocated
-            // while the scanner is in use (the scanner needs to have
-            // access to the data through a reference)
-            let p = Box::new(p.clone());
-            let mut s = Box::new(s.clone());
             s.configure(&p);
             // transmute (!!!!!) the scanner so that its lifetime is 'static
             // (i.e. the reference to the PSSM and sequence never expire),
             // which is possible because we are building a self-referential
             // struct
             let scanner = unsafe {
-                let mut s = lightmotif::scan::Scanner::<Dna, C>::new(&p, &s);
-                s.threshold(threshold);
-                s.block_size(block_size);
-                std::mem::transmute(s)
+                let mut scanner = lightmotif::scan::Scanner::<Dna, _, _, C>::new(p, s);
+                scanner.threshold(threshold);
+                scanner.block_size(block_size);
+                std::mem::transmute(scanner)
             };
-            Ok(Scanner {
-                data: scanner,
-                pssm: Pin::new(p),
-                sequence: Pin::new(s),
-            })
+            Bound::new(
+                py,
+                Scanner {
+                    data: scanner,
+                    pssm: pssm.unbind(),
+                    sequence: sequence.unbind(),
+                },
+            )
         }
         (ScoringMatrixData::Protein(_), StripedSequenceData::Protein(_)) => {
             Err(PyValueError::new_err("protein scanner is not supported"))
diff --git a/lightmotif/src/scan.rs b/lightmotif/src/scan.rs
index 86054b3..d9adc16 100644
--- a/lightmotif/src/scan.rs
+++ b/lightmotif/src/scan.rs
@@ -93,10 +93,16 @@ impl Ord for Hit {
 
 /// A scanner for iterating over scoring matrix hits in a sequence.
 #[derive(Debug)]
-pub struct Scanner<'a, A: Alphabet, C: StrictlyPositive = DefaultColumns> {
-    pssm: &'a ScoringMatrix<A>,
+pub struct Scanner<
+    'a,
+    A: Alphabet,
+    M: AsRef<ScoringMatrix<A>>,
+    S: AsRef<StripedSequence<A, C>>,
+    C: StrictlyPositive = DefaultColumns,
+> {
+    pssm: M,
     dm: DiscreteMatrix<A>,
-    seq: &'a StripedSequence<A, C>,
+    seq: S,
     scores: CowMut<'a, StripedScores<f32, C>>,
     dscores: StripedScores<u8, C>,
     threshold: f32,
@@ -106,13 +112,17 @@ pub struct Scanner<'a, A: Alphabet, C: StrictlyPositive = DefaultColumns> {
     pipeline: Pipeline<A, Dispatch>,
 }
 
-impl<'a, A: Alphabet, C: StrictlyPositive> Scanner<'a, A, C> {
+impl<'a, A, M, S, C> Scanner<'a, A, M, S, C>
+where
+    A: Alphabet,
+    C: StrictlyPositive,
+    M: AsRef<ScoringMatrix<A>>,
+    S: AsRef<StripedSequence<A, C>>,
+{
     /// Create a new scanner for the given matrix and sequence.
-    pub fn new(pssm: &'a ScoringMatrix<A>, seq: &'a StripedSequence<A, C>) -> Self {
+    pub fn new(pssm: M, seq: S) -> Self {
         Self {
-            pssm,
-            seq,
-            dm: pssm.to_discrete(),
+            dm: pssm.as_ref().to_discrete(),
             scores: CowMut::Owned(StripedScores::empty()),
             dscores: StripedScores::empty(),
             threshold: 0.0,
@@ -120,6 +130,8 @@ impl<'a, A: Alphabet, C: StrictlyPositive> Scanner<'a, A, C> {
             row: 0,
             hits: Vec::new(),
             pipeline: Pipeline::dispatch(),
+            pssm,
+            seq,
         }
     }
 
@@ -142,19 +154,22 @@ impl<'a, A: Alphabet, C: StrictlyPositive> Scanner<'a, A, C> {
     }
 }
 
-impl<'a, A, C> Iterator for Scanner<'a, A, C>
+impl<'a, A, M, S, C> Iterator for Scanner<'a, A, M, S, C>
 where
     A: Alphabet,
     C: StrictlyPositive,
+    M: AsRef<ScoringMatrix<A>>,
+    S: AsRef<StripedSequence<A, C>>,
     Pipeline<A, Dispatch>: Score<u8, A, C> + Threshold<u8, C> + Maximum<u8, C>,
 {
     type Item = Hit;
     fn next(&mut self) -> Option<Self::Item> {
+        let seq = self.seq.as_ref();
         let t = self.dm.scale(self.threshold);
-        while self.hits.is_empty() && self.row < self.seq.matrix().rows() {
+        while self.hits.is_empty() && self.row < seq.matrix().rows() {
             // compute the row slice to score in the striped sequence matrix
-            let end = (self.row + self.block_size)
-                .min(self.seq.matrix().rows().saturating_sub(self.seq.wrap()));
+            let end =
+                (self.row + self.block_size).min(seq.matrix().rows().saturating_sub(seq.wrap()));
             // score the row slice
             self.pipeline
                 .score_rows_into(&self.dm, &self.seq, self.row..end, &mut self.dscores);
@@ -163,9 +178,8 @@ where
                 // scan through the positions above discrete threshold and recompute
                 // scores in floating-point to see if they pass the real threshold.
                 for c in self.pipeline.threshold(&self.dscores, t) {
-                    let index =
-                        c.col * (self.seq.matrix().rows() - self.seq.wrap()) + self.row + c.row;
-                    let score = self.pssm.score_position(&self.seq, index);
+                    let index = c.col * (seq.matrix().rows() - seq.wrap()) + self.row + c.row;
+                    let score = self.pssm.as_ref().score_position(seq, index);
                     if score >= self.threshold {
                         self.hits.push(Hit::new(index, score));
                     }
@@ -178,6 +192,8 @@ where
     }
 
     fn max(mut self) -> Option<Self::Item> {
+        let seq = self.seq.as_ref();
+
         // Compute the score of the best hit not yet returned, and translate
         // the `f32` score threshold into a discrete, under-estimate `u8`
         // threshold.
@@ -190,15 +206,15 @@ where
             None => self.dm.scale(self.threshold),
         };
 
-        // Cache th number of sequence rows in the striped sequence matrix.
-        let sequence_rows = self.seq.matrix().rows() - self.seq.wrap();
+        // Cache the number of sequence rows in the striped sequence matrix.
+        let sequence_rows = seq.matrix().rows() - seq.wrap();
 
         // Process all rows of the sequence and record the local
-        while self.row < self.seq.matrix().rows() {
+        while self.row < seq.matrix().rows() {
             // Score the rows of the current block.
             let end = (self.row + self.block_size).min(sequence_rows);
             self.pipeline
-                .score_rows_into(&self.dm, &self.seq, self.row..end, &mut self.dscores);
+                .score_rows_into(&self.dm, seq, self.row..end, &mut self.dscores);
             // Check if the highest score in the block is high enough to be
             // a new global maximum
             if self.pipeline.max(&self.dscores).unwrap() >= best_discrete {
@@ -208,7 +224,7 @@ where
                     let dscore = self.dscores.matrix()[c];
                     if dscore >= best_discrete {
                         let index = c.col * sequence_rows + self.row + c.row;
-                        let score = self.pssm.score_position(&self.seq, index);
+                        let score = self.pssm.as_ref().score_position(&self.seq, index);
                         if let Some(hit) = &best {
                             if (score > hit.score) | (score == hit.score && index > hit.position) {
                                 best = Some(Hit::new(index, score));
diff --git a/lightmotif/src/seq.rs b/lightmotif/src/seq.rs
index 85e4646..0a03476 100644
--- a/lightmotif/src/seq.rs
+++ b/lightmotif/src/seq.rs
@@ -10,7 +10,6 @@ use std::str::FromStr;
 use rand::Rng;
 
 use super::abc::Alphabet;
-use super::abc::Background;
 use super::abc::Symbol;
 use super::dense::DenseMatrix;
 use super::err::InvalidData;
-- 
GitLab