Skip to content
Snippets Groups Projects
Commit 453bdc99 authored by Martin Larralde
Browse files

Update `Encode` trait to return an `EncodedSequence`

parent 5aab90ca
No related branches found
No related tags found
No related merge requests found
......@@ -4,11 +4,14 @@ extern crate generic_array;
extern crate lightmotif;
extern crate pyo3;
use std::sync::RwLock;
use lightmotif::abc::Alphabet;
use lightmotif::abc::Dna;
use lightmotif::abc::Symbol;
use lightmotif::dense::DenseMatrix;
use lightmotif::num::Unsigned;
use lightmotif::pli::dispatch::Dispatch;
use lightmotif::pli::platform::Backend;
use lightmotif::pli::Encode;
use lightmotif::pli::Maximum;
......@@ -19,6 +22,7 @@ use lightmotif::pli::Threshold;
use pyo3::exceptions::PyBufferError;
use pyo3::exceptions::PyIndexError;
use pyo3::exceptions::PyRuntimeError;
use pyo3::exceptions::PyTypeError;
use pyo3::exceptions::PyValueError;
use pyo3::ffi::Py_ssize_t;
......@@ -46,6 +50,28 @@ type C = typenum::consts::U1;
// --- Helpers -----------------------------------------------------------------
static PIPELINE: RwLock<Option<Pipeline<Dna, Dispatch>>> = RwLock::new(None);
/// Store a freshly dispatched pipeline in the global `PIPELINE` slot.
///
/// Any value previously held in the slot is dropped and replaced by the
/// result of `Pipeline::dispatch()`.
///
/// # Errors
///
/// Returns a `PyRuntimeError` when the write lock on `PIPELINE` cannot be
/// acquired.
fn init_pipeline() -> PyResult<()> {
    let mut slot = PIPELINE
        .write()
        .map_err(|_| PyRuntimeError::new_err("Failed to acquire global pipeline"))?;
    // Overwrite whatever was stored before; the old value (if any) is dropped.
    *slot = Some(Pipeline::dispatch());
    Ok(())
}
/// Run a closure with shared access to the global pipeline.
///
/// Takes the read lock on `PIPELINE` and hands a reference to the stored
/// pipeline to `f`, returning whatever `f` produces.
///
/// # Errors
///
/// Returns a `PyRuntimeError` when the read lock cannot be acquired, or when
/// the global slot is still `None` (no pipeline has been stored yet).
fn with_pipeline<T, F>(f: F) -> PyResult<T>
where
    F: FnOnce(&Pipeline<Dna, Dispatch>) -> T,
{
    // The read guard is a temporary of this expression, so it stays alive
    // (and the lock stays held) for the whole time `f` is running.
    PIPELINE
        .read()
        .map_err(|_| PyRuntimeError::new_err("Failed to acquire global lock"))?
        .as_ref()
        .ok_or_else(|| PyRuntimeError::new_err("Global pipeline was not initialized"))
        .map(f)
}
fn dict_to_alphabet_array<'py, A: lightmotif::Alphabet>(
d: &'py PyDict,
) -> PyResult<GenericArray<f32, A::K>> {
......@@ -91,7 +117,7 @@ impl EncodedSequence {
.map_err(|lightmotif::err::InvalidSymbol(x)| {
PyValueError::new_err(format!("Invalid symbol in input: {}", x))
})?;
Ok(EncodedSequence::from(data).into())
Ok(Self::from(data).into())
}
/// Create a copy of this sequence.
......@@ -100,15 +126,14 @@ impl EncodedSequence {
}
/// Convert this sequence into a striped matrix.
pub fn stripe(&self) -> StripedSequence {
let data = Pipeline::dispatch().stripe(&self.data);
StripedSequence { data }
pub fn stripe(&self) -> PyResult<StripedSequence> {
with_pipeline(|pli| StripedSequence::from(pli.stripe(&self.data)))
}
}
impl From<Vec<<Dna as Alphabet>::Symbol>> for EncodedSequence {
fn from(v: Vec<<Dna as Alphabet>::Symbol>) -> Self {
Self { data: v.into() }
impl From<lightmotif::seq::EncodedSequence<lightmotif::Dna>> for EncodedSequence {
fn from(data: lightmotif::seq::EncodedSequence<lightmotif::Dna>) -> Self {
Self { data }
}
}
......@@ -121,6 +146,12 @@ pub struct StripedSequence {
data: lightmotif::seq::StripedSequence<lightmotif::Dna, C>,
}
impl From<lightmotif::seq::StripedSequence<lightmotif::Dna, C>> for StripedSequence {
fn from(data: lightmotif::seq::StripedSequence<lightmotif::Dna, C>) -> Self {
Self { data }
}
}
// --- CountMatrix -------------------------------------------------------------
/// A matrix storing the count of each motif letter at each position.
......@@ -312,7 +343,7 @@ impl ScoringMatrix {
let scores = slf
.py()
.allow_threads(|| Pipeline::dispatch().score(seq, pssm));
.allow_threads(|| with_pipeline(|pli| pli.score(seq, pssm)))?;
Ok(StripedScores::from(scores))
}
......@@ -398,7 +429,7 @@ impl StripedScores {
let scores = &slf.scores;
let indices = slf
.py()
.allow_threads(|| Pipeline::<Dna, _>::dispatch().threshold(scores, threshold));
.allow_threads(|| with_pipeline(|pli| pli.threshold(scores, threshold)))?;
Ok(indices)
}
......@@ -415,7 +446,7 @@ impl StripedScores {
let scores = &slf.scores;
let indices = slf
.py()
.allow_threads(|| Pipeline::<Dna, _>::dispatch().max(scores));
.allow_threads(|| with_pipeline(|pli| pli.max(scores)))?;
Ok(indices)
}
......@@ -432,7 +463,7 @@ impl StripedScores {
let scores = &slf.scores;
let indices = slf
.py()
.allow_threads(|| Pipeline::<Dna, _>::dispatch().argmax(scores));
.allow_threads(|| with_pipeline(|pli| pli.argmax(scores)))?;
Ok(indices)
}
}
......@@ -524,7 +555,7 @@ pub fn stripe<'py>(sequence: &'py PyAny) -> PyResult<StripedSequence> {
let s = sequence.extract::<&PyString>()?;
let encoded = EncodedSequence::__init__(s).and_then(|e| Py::new(py, e))?;
let striped = encoded.borrow(py).stripe();
Ok(striped)
striped
}
/// PyO3 bindings to ``lightmotif``, a library for fast PWM motif scanning.
......@@ -556,5 +587,6 @@ pub fn init(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_function(wrap_pyfunction!(create, m)?)?;
m.add_function(wrap_pyfunction!(stripe, m)?)?;
init_pipeline()?;
Ok(())
}
......@@ -20,6 +20,7 @@ use super::err::InvalidSymbol;
use super::err::UnsupportedBackend;
use super::num::StrictlyPositive;
use super::pwm::ScoringMatrix;
use super::seq::EncodedSequence;
use super::seq::StripedSequence;
use typenum::consts::U16;
......@@ -35,7 +36,7 @@ mod scores;
/// Used for encoding a sequence into rank-based encoding.
pub trait Encode<A: Alphabet> {
/// Encode the given sequence into a vector of symbols.
fn encode<S: AsRef<[u8]>>(&self, seq: S) -> Result<Vec<A::Symbol>, InvalidSymbol> {
fn encode_raw<S: AsRef<[u8]>>(&self, seq: S) -> Result<Vec<A::Symbol>, InvalidSymbol> {
let s = seq.as_ref();
let mut buffer = Vec::with_capacity(s.len());
unsafe { buffer.set_len(s.len()) };
......@@ -45,6 +46,11 @@ pub trait Encode<A: Alphabet> {
}
}
/// Encode the given sequence into an `EncodedSequence`.
///
/// The symbols are produced by `encode_raw` and then wrapped with
/// `EncodedSequence::new`.
///
/// # Errors
///
/// Propagates the `InvalidSymbol` error returned by `encode_raw`.
fn encode<S: AsRef<[u8]>>(&self, seq: S) -> Result<EncodedSequence<A>, InvalidSymbol> {
    let symbols = self.encode_raw(seq)?;
    Ok(EncodedSequence::new(symbols))
}
/// Encode the given sequence into a buffer of symbols.
///
/// The destination buffer is expected to be large enough to store the
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment