Skip to content
Snippets Groups Projects
Commit fd54401d authored by Martin Larralde's avatar Martin Larralde
Browse files

Update `lightmotif-py` with the new `typenum` API

parent 53b25d75
No related branches found
No related tags found
No related merge requests found
...@@ -20,6 +20,8 @@ path = "../lightmotif" ...@@ -20,6 +20,8 @@ path = "../lightmotif"
version = "0.1.1" version = "0.1.1"
[dependencies] [dependencies]
pyo3 = "0.18.3" pyo3 = "0.18.3"
typenum = "1.16"
generic-array = "0.14"
[features] [features]
default = [] default = []
......
#![doc = include_str!("../README.md")] #![doc = include_str!("../README.md")]
extern crate generic_array;
extern crate lightmotif; extern crate lightmotif;
extern crate pyo3; extern crate pyo3;
extern crate typenum;
#[cfg(target_arch = "x86_64")] #[cfg(target_arch = "x86_64")]
use std::arch::x86_64::__m256; use std::arch::x86_64::__m256i;
#[cfg(target_arch = "x86")] #[cfg(target_arch = "x86")]
use std::arch::x86_64::__m256; use std::arch::x86_64::__m256i;
use lightmotif as lm; use lightmotif as lm;
#[allow(unused)]
use lightmotif::Alphabet; use lightmotif::Alphabet;
use lightmotif::Pipeline; use lightmotif::Pipeline;
use lightmotif::Score; use lightmotif::Score;
...@@ -24,19 +27,24 @@ use pyo3::types::PyDict; ...@@ -24,19 +27,24 @@ use pyo3::types::PyDict;
use pyo3::types::PyString; use pyo3::types::PyString;
use pyo3::AsPyPointer; use pyo3::AsPyPointer;
use generic_array::GenericArray;
use typenum::marker_traits::Unsigned;
// --- Compile-time constants -------------------------------------------------- // --- Compile-time constants --------------------------------------------------
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
const C: usize = std::mem::size_of::<__m256>(); type C = <__m256i as lightmotif::Vector>::LANES;
#[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))] #[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
const C: usize = 1; type C = <u8 as lightmotif::Vector>::LANES;
// --- Helpers ----------------------------------------------------------------- // --- Helpers -----------------------------------------------------------------
fn dict_to_alphabet_array<'py, A: lm::Alphabet, const K: usize>( fn dict_to_alphabet_array<'py, A: lm::Alphabet>(
d: &'py PyDict, d: &'py PyDict,
) -> PyResult<[f32; K]> { ) -> PyResult<GenericArray<f32, A::K>> {
let mut p = [0.0; K]; let mut p = std::iter::repeat(0.0)
.take(A::K::USIZE)
.collect::<GenericArray<f32, A::K>>();
for (k, v) in d.iter() { for (k, v) in d.iter() {
let s = k.extract::<&PyString>()?.to_str()?; let s = k.extract::<&PyString>()?.to_str()?;
if s.len() != 1 { if s.len() != 1 {
...@@ -100,7 +108,7 @@ pub struct StripedSequence { ...@@ -100,7 +108,7 @@ pub struct StripedSequence {
#[pyclass(module = "lightmotif.lib")] #[pyclass(module = "lightmotif.lib")]
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct CountMatrix { pub struct CountMatrix {
data: lm::CountMatrix<lm::Dna, { lm::Dna::K }>, data: lm::CountMatrix<lm::Dna>,
} }
#[pymethods] #[pymethods]
...@@ -111,7 +119,7 @@ impl CountMatrix { ...@@ -111,7 +119,7 @@ impl CountMatrix {
if let Ok(x) = obj.extract::<f32>(py) { if let Ok(x) = obj.extract::<f32>(py) {
Ok(lm::Pseudocounts::from(x)) Ok(lm::Pseudocounts::from(x))
} else if let Ok(d) = obj.extract::<&PyDict>(py) { } else if let Ok(d) = obj.extract::<&PyDict>(py) {
let p = dict_to_alphabet_array::<lm::Dna, { lm::Dna::K }>(d)?; let p = dict_to_alphabet_array::<lm::Dna>(d)?;
Ok(lm::Pseudocounts::from(p)) Ok(lm::Pseudocounts::from(p))
} else { } else {
Err(PyTypeError::new_err("Invalid type for pseudocount")) Err(PyTypeError::new_err("Invalid type for pseudocount"))
...@@ -125,8 +133,8 @@ impl CountMatrix { ...@@ -125,8 +133,8 @@ impl CountMatrix {
} }
} }
impl From<lm::CountMatrix<lm::Dna, { lm::Dna::K }>> for CountMatrix { impl From<lm::CountMatrix<lm::Dna>> for CountMatrix {
fn from(data: lm::CountMatrix<lm::Dna, { lm::Dna::K }>) -> Self { fn from(data: lm::CountMatrix<lm::Dna>) -> Self {
Self { data } Self { data }
} }
} }
...@@ -136,7 +144,7 @@ impl From<lm::CountMatrix<lm::Dna, { lm::Dna::K }>> for CountMatrix { ...@@ -136,7 +144,7 @@ impl From<lm::CountMatrix<lm::Dna, { lm::Dna::K }>> for CountMatrix {
#[pyclass(module = "lightmotif.lib")] #[pyclass(module = "lightmotif.lib")]
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct WeightMatrix { pub struct WeightMatrix {
data: lm::WeightMatrix<lm::Dna, { lm::Dna::K }>, data: lm::WeightMatrix<lm::Dna>,
} }
#[pymethods] #[pymethods]
...@@ -146,7 +154,7 @@ impl WeightMatrix { ...@@ -146,7 +154,7 @@ impl WeightMatrix {
let bg = Python::with_gil(|py| { let bg = Python::with_gil(|py| {
if let Some(obj) = background { if let Some(obj) = background {
if let Ok(d) = obj.extract::<&PyDict>(py) { if let Ok(d) = obj.extract::<&PyDict>(py) {
let p = dict_to_alphabet_array::<lm::Dna, { lm::Dna::K }>(d)?; let p = dict_to_alphabet_array::<lm::Dna>(d)?;
lm::Background::new(p) lm::Background::new(p)
.map_err(|_| PyValueError::new_err("Invalid background frequencies")) .map_err(|_| PyValueError::new_err("Invalid background frequencies"))
} else { } else {
...@@ -165,8 +173,8 @@ impl WeightMatrix { ...@@ -165,8 +173,8 @@ impl WeightMatrix {
} }
} }
impl From<lm::WeightMatrix<lm::Dna, { lm::Dna::K }>> for WeightMatrix { impl From<lm::WeightMatrix<lm::Dna>> for WeightMatrix {
fn from(data: lm::WeightMatrix<lm::Dna, { lm::Dna::K }>) -> Self { fn from(data: lm::WeightMatrix<lm::Dna>) -> Self {
Self { data } Self { data }
} }
} }
...@@ -176,7 +184,7 @@ impl From<lm::WeightMatrix<lm::Dna, { lm::Dna::K }>> for WeightMatrix { ...@@ -176,7 +184,7 @@ impl From<lm::WeightMatrix<lm::Dna, { lm::Dna::K }>> for WeightMatrix {
#[pyclass(module = "lightmotif.lib")] #[pyclass(module = "lightmotif.lib")]
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct ScoringMatrix { pub struct ScoringMatrix {
data: lm::ScoringMatrix<lm::Dna, { lm::Dna::K }>, data: lm::ScoringMatrix<lm::Dna>,
} }
#[pymethods] #[pymethods]
...@@ -193,17 +201,17 @@ impl ScoringMatrix { ...@@ -193,17 +201,17 @@ impl ScoringMatrix {
let scores = slf.py().allow_threads(|| { let scores = slf.py().allow_threads(|| {
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
if std::is_x86_feature_detected!("avx2") { if std::is_x86_feature_detected!("avx2") {
return Pipeline::<lm::Dna, __m256>::score(seq, pssm); return Pipeline::<lm::Dna, __m256i>::score(seq, pssm);
} }
Pipeline::<lm::Dna, f32>::score(seq, pssm) Pipeline::<lm::Dna, u8>::score(seq, pssm)
}); });
Ok(StripedScores::from(scores)) Ok(StripedScores::from(scores))
} }
} }
impl From<lm::ScoringMatrix<lm::Dna, { lm::Dna::K }>> for ScoringMatrix { impl From<lm::ScoringMatrix<lm::Dna>> for ScoringMatrix {
fn from(data: lm::ScoringMatrix<lm::Dna, { lm::Dna::K }>) -> Self { fn from(data: lm::ScoringMatrix<lm::Dna>) -> Self {
Self { data } Self { data }
} }
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment