Skip to content
Snippets Groups Projects
Commit 5c36d18f authored by Martin Larralde's avatar Martin Larralde
Browse files

Wrap the new DNA scanning algorithm in the `lightmotif-py` library

parent 76d0eda1
No related branches found
No related tags found
No related merge requests found
...@@ -11,6 +11,7 @@ from .lib import ( ...@@ -11,6 +11,7 @@ from .lib import (
StripedScores, StripedScores,
create, create,
stripe, stripe,
scan,
) )
__author__ = lib.__author__ __author__ = lib.__author__
......
...@@ -8,6 +8,7 @@ extern crate pyo3; ...@@ -8,6 +8,7 @@ extern crate pyo3;
use std::fmt::Display; use std::fmt::Display;
use std::fmt::Formatter; use std::fmt::Formatter;
use std::pin::Pin;
use lightmotif::abc::Alphabet; use lightmotif::abc::Alphabet;
use lightmotif::abc::Dna; use lightmotif::abc::Dna;
...@@ -1008,7 +1009,52 @@ pub struct Motif { ...@@ -1008,7 +1009,52 @@ pub struct Motif {
pssm: Py<ScoringMatrix>, pssm: Py<ScoringMatrix>,
} }
// --- Module ------------------------------------------------------------------ // --- Scanner -----------------------------------------------------------------
#[derive(Debug)]
enum ScannerData<'py> {
Dna(lightmotif::scan::Scanner<'py, Dna, C>),
// Protein(lightmotif::scan::Scanner::<'py, Protein, C>),
}
#[pyclass(module = "lightmotif.lib")]
#[derive(Debug)]
pub struct Scanner {
pssm: Pin<Box<lightmotif::pwm::ScoringMatrix<Dna>>>,
sequence: Pin<Box<lightmotif::seq::StripedSequence<Dna, C>>>,
data: lightmotif::scan::Scanner<'static, Dna, C>,
}
#[pymethods]
impl Scanner {
fn __iter__(slf: PyRef<Self>) -> PyRef<Self> {
slf
}
fn __next__(mut slf: PyRefMut<Self>) -> Option<Hit> {
slf.data.next().map(Hit::from)
}
}
#[pyclass(module = "lightmotif.lib")]
#[derive(Debug)]
pub struct Hit {
#[pyo3(get)]
position: usize,
#[pyo3(get)]
score: f32,
}
impl From<lightmotif::scan::Hit> for Hit {
fn from(value: lightmotif::scan::Hit) -> Self {
Hit {
position: value.position,
score: value.score,
}
}
}
// --- Functions ---------------------------------------------------------------
/// Create a new motif from an iterable of sequences. /// Create a new motif from an iterable of sequences.
/// ///
...@@ -1080,6 +1126,48 @@ pub fn stripe(sequence: Bound<PyAny>, protein: bool) -> PyResult<StripedSequence ...@@ -1080,6 +1126,48 @@ pub fn stripe(sequence: Bound<PyAny>, protein: bool) -> PyResult<StripedSequence
striped striped
} }
/// Scan a sequence using a fast scanner implementation to identify hits.
#[pyfunction]
#[pyo3(signature = (pssm, sequence, threshold = 0.0, block_size = 256))]
pub fn scan(
pssm: Bound<ScoringMatrix>,
sequence: Bound<StripedSequence>,
threshold: f32,
block_size: usize,
) -> PyResult<Scanner> {
match (&pssm.try_borrow()?.data, &sequence.try_borrow()?.data) {
(ScoringMatrixData::Dna(p), StripedSequenceData::Dna(s)) => {
// Move data to the heap so that it doesn't get unallocated
// while the scanner is in use (the scanner needs to have
// access to the data through a reference)
let p = Box::new(p.clone());
let mut s = Box::new(s.clone());
s.configure(&p);
// transmute (!!!!!) the scanner so that its lifetime is 'static
// (i.e. the reference to the PSSM and sequence never expire),
// which is possible because we are building a self-referential
// struct
let scanner = unsafe {
let mut s = lightmotif::scan::Scanner::<Dna, C>::new(&p, &s);
s.threshold(threshold);
s.block_size(block_size);
std::mem::transmute(s)
};
Ok(Scanner {
data: scanner,
pssm: Pin::new(p),
sequence: Pin::new(s),
})
}
(ScoringMatrixData::Protein(_), StripedSequenceData::Protein(_)) => {
Err(PyValueError::new_err("protein scanner is not supported"))
}
(_, _) => Err(PyValueError::new_err("alphabet mismatch")),
}
}
// --- Module ------------------------------------------------------------------
/// PyO3 bindings to ``lightmotif``, a library for fast PWM motif scanning. /// PyO3 bindings to ``lightmotif``, a library for fast PWM motif scanning.
/// ///
/// The API is similar to the `Bio.motifs` module from Biopython on purpose. /// The API is similar to the `Bio.motifs` module from Biopython on purpose.
...@@ -1122,7 +1210,11 @@ pub fn init<'py>(_py: Python<'py>, m: &Bound<PyModule>) -> PyResult<()> { ...@@ -1122,7 +1210,11 @@ pub fn init<'py>(_py: Python<'py>, m: &Bound<PyModule>) -> PyResult<()> {
m.add_class::<Motif>()?; m.add_class::<Motif>()?;
m.add_class::<Scanner>()?;
m.add_class::<Hit>()?;
m.add_function(wrap_pyfunction!(create, m)?)?; m.add_function(wrap_pyfunction!(create, m)?)?;
m.add_function(wrap_pyfunction!(scan, m)?)?;
m.add_function(wrap_pyfunction!(stripe, m)?)?; m.add_function(wrap_pyfunction!(stripe, m)?)?;
Ok(()) Ok(())
......
from . import test_doctest, test_pipeline, test_sequence, test_pvalue from . import (
test_doctest,
test_pipeline,
test_scanner,
test_sequence,
test_pvalue,
)
def load_tests(loader, suite, pattern): def load_tests(loader, suite, pattern):
suite.addTests(loader.loadTestsFromModule(test_doctest))
suite.addTests(loader.loadTestsFromModule(test_pipeline)) suite.addTests(loader.loadTestsFromModule(test_pipeline))
suite.addTests(loader.loadTestsFromModule(test_sequence))
suite.addTests(loader.loadTestsFromModule(test_pvalue)) suite.addTests(loader.loadTestsFromModule(test_pvalue))
suite.addTests(loader.loadTestsFromModule(test_doctest)) suite.addTests(loader.loadTestsFromModule(test_scanner))
suite.addTests(loader.loadTestsFromModule(test_sequence))
return suite return suite
import gzip
import os
import tempfile
import unittest
import lightmotif
SEQUENCE = "ATGTCCCAACAACGATACCCCGAGCCCATCGCCGTCATCGGCTCGGCATGCAGATTCCCAGGCG"
EXPECTED = [
-23.07094,
-18.678621,
-15.219191,
-17.745737,
-18.678621,
-23.07094,
-17.745737,
-19.611507,
-27.463257,
-29.989803,
-14.286304,
-26.53037,
-15.219191,
-10.826873,
-10.826873,
-22.138054,
-38.774437,
-30.922688,
-5.50167,
-24.003826,
-18.678621,
-15.219191,
-35.315006,
-17.745737,
-10.826873,
-30.922688,
-23.07094,
-6.4345555,
-31.855574,
-23.07094,
-15.219191,
-31.855574,
-8.961102,
-26.53037,
-27.463257,
-14.286304,
-15.219191,
-26.53037,
-23.07094,
-18.678621,
-14.286304,
-18.678621,
-26.53037,
-16.152077,
-17.745737,
-18.678621,
-17.745737,
-14.286304,
-30.922688,
-18.678621,
]
class TestScanner(unittest.TestCase):
def test_scan(self):
motif = lightmotif.create(["GTTGACCTTATCAAC", "GTTGATCCAGTCAAC"])
frequencies = motif.counts.normalize(0.1)
pssm = frequencies.log_odds()
seq = lightmotif.stripe(SEQUENCE)
hits = list(lightmotif.scan(pssm, seq))
self.assertEqual(len(hits), 0)
hits = list(lightmotif.scan(pssm, seq, threshold=-10.0))
self.assertEqual(len(hits), 3)
hits.sort(key=lambda h: h.position)
self.assertAlmostEqual(hits[0].score, -5.50167, places=5)
self.assertAlmostEqual(hits[1].score, -6.4345555, places=5)
self.assertAlmostEqual(hits[2].score, -8.961102, places=5)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment