Skip to content
Snippets Groups Projects
Commit d14173a9 authored by Martin Larralde's avatar Martin Larralde
Browse files

Make `StripedScores` generic over the score type

parent 3141f9f7
No related branches found
No related tags found
No related merge requests found
......@@ -570,7 +570,7 @@ impl From<lightmotif::ScoringMatrix<lightmotif::Dna>> for ScoringMatrix {
#[pyclass(module = "lightmotif.lib", sequence)]
#[derive(Clone, Debug)]
pub struct StripedScores {
scores: lightmotif::scores::StripedScores<C>,
scores: lightmotif::scores::StripedScores<f32, C>,
shape: [Py_ssize_t; 2],
strides: [Py_ssize_t; 2],
}
......@@ -668,8 +668,8 @@ impl StripedScores {
}
}
impl From<lightmotif::scores::StripedScores<C>> for StripedScores {
fn from(scores: lightmotif::scores::StripedScores<C>) -> Self {
impl From<lightmotif::scores::StripedScores<f32, C>> for StripedScores {
fn from(scores: lightmotif::scores::StripedScores<f32, C>) -> Self {
// assert_eq!(scores.range().start, 0);
// extract the matrix shape
let cols = scores.matrix().columns();
......
......@@ -78,7 +78,7 @@ impl Score<Dna, <Dispatch as Backend>::LANES> for Pipeline<Dna, Dispatch> {
pssm: M,
seq: S,
rows: Range<usize>,
scores: &mut StripedScores<<Dispatch as Backend>::LANES>,
scores: &mut StripedScores<f32, <Dispatch as Backend>::LANES>,
) where
S: AsRef<StripedSequence<Dna, <Dispatch as Backend>::LANES>>,
M: AsRef<ScoringMatrix<Dna>>,
......@@ -107,7 +107,7 @@ impl Score<Protein, <Dispatch as Backend>::LANES> for Pipeline<Protein, Dispatch
pssm: M,
seq: S,
rows: Range<usize>,
scores: &mut StripedScores<<Dispatch as Backend>::LANES>,
scores: &mut StripedScores<f32, <Dispatch as Backend>::LANES>,
) where
S: AsRef<StripedSequence<Protein, <Dispatch as Backend>::LANES>>,
M: AsRef<ScoringMatrix<Protein>>,
......@@ -149,7 +149,7 @@ impl<A: Alphabet> Stripe<A, <Dispatch as Backend>::LANES> for Pipeline<A, Dispat
impl<A: Alphabet> Maximum<<Dispatch as Backend>::LANES> for Pipeline<A, Dispatch> {
fn argmax(
&self,
scores: &StripedScores<<Dispatch as Backend>::LANES>,
scores: &StripedScores<f32, <Dispatch as Backend>::LANES>,
) -> Option<MatrixCoordinates> {
match self.backend {
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
......
......@@ -73,7 +73,7 @@ pub trait Score<A: Alphabet, C: StrictlyPositive> {
pssm: M,
seq: S,
rows: Range<usize>,
scores: &mut StripedScores<C>,
scores: &mut StripedScores<f32, C>,
) where
S: AsRef<StripedSequence<A, C>>,
M: AsRef<ScoringMatrix<A>>,
......@@ -105,7 +105,7 @@ pub trait Score<A: Alphabet, C: StrictlyPositive> {
}
/// Compute the PSSM scores into the given striped score matrix.
fn score_into<S, M>(&self, pssm: M, seq: S, scores: &mut StripedScores<C>)
fn score_into<S, M>(&self, pssm: M, seq: S, scores: &mut StripedScores<f32, C>)
where
S: AsRef<StripedSequence<A, C>>,
M: AsRef<ScoringMatrix<A>>,
......@@ -117,7 +117,7 @@ pub trait Score<A: Alphabet, C: StrictlyPositive> {
}
/// Compute the PSSM scores for every sequence positions.
fn score<S, M>(&self, pssm: M, seq: S) -> StripedScores<C>
fn score<S, M>(&self, pssm: M, seq: S) -> StripedScores<f32, C>
where
S: AsRef<StripedSequence<A, C>>,
M: AsRef<ScoringMatrix<A>>,
......@@ -133,7 +133,7 @@ pub trait Score<A: Alphabet, C: StrictlyPositive> {
/// Used for finding the highest scoring site in a striped score matrix.
pub trait Maximum<C: StrictlyPositive> {
/// Find the matrix coordinates with the highest score.
fn argmax(&self, scores: &StripedScores<C>) -> Option<MatrixCoordinates> {
fn argmax(&self, scores: &StripedScores<f32, C>) -> Option<MatrixCoordinates> {
if scores.is_empty() {
return None;
}
......@@ -156,7 +156,7 @@ pub trait Maximum<C: StrictlyPositive> {
}
/// Find the highest score.
fn max(&self, scores: &StripedScores<C>) -> Option<f32> {
fn max(&self, scores: &StripedScores<f32, C>) -> Option<f32> {
self.argmax(scores).map(|c| scores.matrix()[c])
}
}
......@@ -204,7 +204,7 @@ pub trait Threshold<C: StrictlyPositive> {
/// # Note
/// The indices are not be sorted, and the actual order depends on the
/// implementation.
fn threshold(&self, scores: &StripedScores<C>, threshold: f32) -> Vec<MatrixCoordinates> {
fn threshold(&self, scores: &StripedScores<f32, C>, threshold: f32) -> Vec<MatrixCoordinates> {
let mut positions = Vec::new();
for (i, row) in scores.matrix().iter().enumerate() {
for col in 0..C::USIZE {
......@@ -323,7 +323,7 @@ where
pssm: M,
seq: S,
rows: Range<usize>,
scores: &mut StripedScores<C>,
scores: &mut StripedScores<f32, C>,
) where
S: AsRef<StripedSequence<A, C>>,
M: AsRef<ScoringMatrix<A>>,
......@@ -337,7 +337,7 @@ where
A: Alphabet,
C: StrictlyPositive + MultipleOf<U16>,
{
fn argmax(&self, scores: &StripedScores<C>) -> Option<MatrixCoordinates> {
fn argmax(&self, scores: &StripedScores<f32, C>) -> Option<MatrixCoordinates> {
Sse2::argmax(scores)
}
}
......@@ -378,7 +378,7 @@ impl Score<Dna, <Avx2 as Backend>::LANES> for Pipeline<Dna, Avx2> {
pssm: M,
seq: S,
rows: Range<usize>,
scores: &mut StripedScores<<Avx2 as Backend>::LANES>,
scores: &mut StripedScores<f32, <Avx2 as Backend>::LANES>,
) where
S: AsRef<StripedSequence<Dna, <Avx2 as Backend>::LANES>>,
M: AsRef<ScoringMatrix<Dna>>,
......@@ -393,7 +393,7 @@ impl Score<Protein, <Avx2 as Backend>::LANES> for Pipeline<Protein, Avx2> {
pssm: M,
seq: S,
rows: Range<usize>,
scores: &mut StripedScores<<Avx2 as Backend>::LANES>,
scores: &mut StripedScores<f32, <Avx2 as Backend>::LANES>,
) where
S: AsRef<StripedSequence<Protein, <Avx2 as Backend>::LANES>>,
M: AsRef<ScoringMatrix<Protein>>,
......@@ -416,7 +416,7 @@ impl<A: Alphabet> Stripe<A, <Avx2 as Backend>::LANES> for Pipeline<A, Avx2> {
impl<A: Alphabet> Maximum<<Avx2 as Backend>::LANES> for Pipeline<A, Avx2> {
fn argmax(
&self,
scores: &StripedScores<<Avx2 as Backend>::LANES>,
scores: &StripedScores<f32, <Avx2 as Backend>::LANES>,
) -> Option<MatrixCoordinates> {
Avx2::argmax(scores)
}
......
......@@ -99,7 +99,7 @@ unsafe fn score_avx2_permute<A>(
pssm: &ScoringMatrix<A>,
seq: &StripedSequence<A, <Avx2 as Backend>::LANES>,
rows: Range<usize>,
scores: &mut StripedScores<<Avx2 as Backend>::LANES>,
scores: &mut StripedScores<f32, <Avx2 as Backend>::LANES>,
) where
A: Alphabet,
<A as Alphabet>::K: IsLessOrEqual<U8>,
......@@ -190,7 +190,7 @@ unsafe fn score_avx2_gather<A>(
pssm: &ScoringMatrix<A>,
seq: &StripedSequence<A, <Avx2 as Backend>::LANES>,
rows: Range<usize>,
scores: &mut StripedScores<<Avx2 as Backend>::LANES>,
scores: &mut StripedScores<f32, <Avx2 as Backend>::LANES>,
) where
A: Alphabet,
{
......@@ -272,7 +272,7 @@ unsafe fn score_avx2_gather<A>(
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn argmax_avx2(
scores: &StripedScores<<Avx2 as Backend>::LANES>,
scores: &StripedScores<f32, <Avx2 as Backend>::LANES>,
) -> Option<MatrixCoordinates> {
if scores.max_index() > u32::MAX as usize {
panic!(
......@@ -615,7 +615,7 @@ impl Avx2 {
pssm: M,
seq: S,
rows: Range<usize>,
scores: &mut StripedScores<<Avx2 as Backend>::LANES>,
scores: &mut StripedScores<f32, <Avx2 as Backend>::LANES>,
) where
A: Alphabet,
<A as Alphabet>::K: IsLessOrEqual<U8>,
......@@ -652,7 +652,7 @@ impl Avx2 {
pssm: M,
seq: S,
rows: Range<usize>,
scores: &mut StripedScores<<Avx2 as Backend>::LANES>,
scores: &mut StripedScores<f32, <Avx2 as Backend>::LANES>,
) where
A: Alphabet,
S: AsRef<StripedSequence<A, <Avx2 as Backend>::LANES>>,
......@@ -683,7 +683,9 @@ impl Avx2 {
}
#[allow(unused)]
pub fn argmax(scores: &StripedScores<<Avx2 as Backend>::LANES>) -> Option<MatrixCoordinates> {
pub fn argmax(
scores: &StripedScores<f32, <Avx2 as Backend>::LANES>,
) -> Option<MatrixCoordinates> {
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
unsafe {
argmax_avx2(scores)
......
......@@ -198,7 +198,7 @@ impl Neon {
pssm: M,
seq: S,
rows: Range<usize>,
scores: &mut StripedScores<C>,
scores: &mut StripedScores<f32, C>,
) where
A: Alphabet,
C: MultipleOf<U16>,
......
......@@ -34,7 +34,7 @@ unsafe fn score_sse2<A: Alphabet, C: MultipleOf<<Sse2 as Backend>::LANES>>(
pssm: &ScoringMatrix<A>,
seq: &StripedSequence<A, C>,
rows: Range<usize>,
scores: &mut StripedScores<C>,
scores: &mut StripedScores<f32, C>,
) {
// mask vectors for broadcasting uint8x16_t to uint32x4_t to floatx4_t
let zero = _mm_setzero_si128();
......@@ -97,7 +97,7 @@ unsafe fn score_sse2<A: Alphabet, C: MultipleOf<<Sse2 as Backend>::LANES>>(
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse2")]
unsafe fn argmax_sse2<C: MultipleOf<<Sse2 as Backend>::LANES>>(
scores: &StripedScores<C>,
scores: &StripedScores<f32, C>,
) -> Option<MatrixCoordinates> {
if scores.max_index() > u32::MAX as usize {
panic!(
......@@ -188,7 +188,7 @@ impl Sse2 {
pssm: M,
seq: S,
rows: Range<usize>,
scores: &mut StripedScores<C>,
scores: &mut StripedScores<f32, C>,
) where
A: Alphabet,
C: MultipleOf<<Sse2 as Backend>::LANES>,
......@@ -221,7 +221,7 @@ impl Sse2 {
#[allow(unused)]
pub fn argmax<C: MultipleOf<<Sse2 as Backend>::LANES>>(
scores: &StripedScores<C>,
scores: &StripedScores<f32, C>,
) -> Option<MatrixCoordinates> {
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
unsafe {
......
......@@ -479,7 +479,7 @@ impl<A: Alphabet> ScoringMatrix<A> {
///
/// # Note
/// Uses platform-accelerated implementation when available.
pub fn score<S, C>(&self, seq: S) -> StripedScores<C>
pub fn score<S, C>(&self, seq: S) -> StripedScores<f32, C>
where
C: StrictlyPositive,
S: AsRef<StripedSequence<A, C>>,
......
......@@ -61,7 +61,7 @@ impl Hit {
pub struct Scanner<'a, A: Alphabet> {
pssm: &'a ScoringMatrix<A>,
seq: &'a StripedSequence<A, C>,
scores: CowMut<'a, StripedScores<C>>,
scores: CowMut<'a, StripedScores<f32, C>>,
threshold: f32,
block_size: usize,
row: usize,
......@@ -85,7 +85,7 @@ impl<'a, A: Alphabet> Scanner<'a, A> {
}
/// Use the given `StripedScores` as a buffer.
pub fn scores(&mut self, scores: &'a mut StripedScores<C>) -> &mut Self {
pub fn scores(&mut self, scores: &'a mut StripedScores<f32, C>) -> &mut Self {
self.scores = CowMut::Borrowed(scores);
self
}
......
......@@ -8,6 +8,7 @@ use std::ops::Range;
use crate::abc::Dna;
use crate::dense::DenseMatrix;
use crate::dense::MatrixCoordinates;
use crate::dense::MatrixElement;
use crate::err::InvalidData;
use crate::num::StrictlyPositive;
use crate::pli::dispatch::Dispatch;
......@@ -85,16 +86,16 @@ impl<T> From<Scores<T>> for Vec<T> {
/// Striped matrix storing scores for an equally striped sequence.
#[derive(Clone, Debug)]
pub struct StripedScores<C: StrictlyPositive> {
pub struct StripedScores<T: MatrixElement, C: StrictlyPositive> {
/// The raw data matrix storing the scores.
data: DenseMatrix<f32, C>,
data: DenseMatrix<T, C>,
/// The total length of the `StripedSequence` these scores were obtained from.
max_index: usize,
}
impl<C: StrictlyPositive> StripedScores<C> {
impl<T: MatrixElement, C: StrictlyPositive> StripedScores<T, C> {
/// Create a new striped score matrix with the given length and data.
fn new(data: DenseMatrix<f32, C>, max_index: usize) -> Result<Self, InvalidData> {
fn new(data: DenseMatrix<T, C>, max_index: usize) -> Result<Self, InvalidData> {
Ok(Self { data, max_index })
}
......@@ -114,12 +115,12 @@ impl<C: StrictlyPositive> StripedScores<C> {
}
/// Return a reference to the striped matrix storing the scores.
pub fn matrix(&self) -> &DenseMatrix<f32, C> {
pub fn matrix(&self) -> &DenseMatrix<T, C> {
&self.data
}
/// Return a mutable reference to the striped matrix storing the scores.
pub fn matrix_mut(&mut self) -> &mut DenseMatrix<f32, C> {
pub fn matrix_mut(&mut self) -> &mut DenseMatrix<T, C> {
&mut self.data
}
......@@ -137,18 +138,18 @@ impl<C: StrictlyPositive> StripedScores<C> {
/// Iterate over scores of individual sequence positions.
#[inline]
pub fn iter(&self) -> Iter<'_, C> {
pub fn iter(&self) -> Iter<'_, T, C> {
Iter::new(self)
}
/// Convert the striped scores into an array.
#[inline]
pub fn unstripe(&self) -> Scores<f32> {
self.iter().cloned().collect::<Vec<f32>>().into()
pub fn unstripe(&self) -> Scores<T> {
self.iter().cloned().collect::<Vec<T>>().into()
}
}
impl<C: StrictlyPositive> StripedScores<C>
impl<C: StrictlyPositive> StripedScores<f32, C>
where
Pipeline<Dna, Dispatch>: Maximum<C>,
{
......@@ -169,7 +170,7 @@ where
}
}
impl<C: StrictlyPositive> StripedScores<C>
impl<C: StrictlyPositive> StripedScores<f32, C>
where
Pipeline<Dna, Dispatch>: Threshold<C>,
{
......@@ -190,49 +191,49 @@ where
}
}
impl<C: StrictlyPositive> AsRef<DenseMatrix<f32, C>> for StripedScores<C> {
fn as_ref(&self) -> &DenseMatrix<f32, C> {
impl<T: MatrixElement, C: StrictlyPositive> AsRef<DenseMatrix<T, C>> for StripedScores<T, C> {
fn as_ref(&self) -> &DenseMatrix<T, C> {
self.matrix()
}
}
impl<C: StrictlyPositive> AsMut<DenseMatrix<f32, C>> for StripedScores<C> {
fn as_mut(&mut self) -> &mut DenseMatrix<f32, C> {
impl<T: MatrixElement, C: StrictlyPositive> AsMut<DenseMatrix<T, C>> for StripedScores<T, C> {
fn as_mut(&mut self) -> &mut DenseMatrix<T, C> {
self.matrix_mut()
}
}
impl<C: StrictlyPositive> Default for StripedScores<C> {
impl<T: MatrixElement, C: StrictlyPositive> Default for StripedScores<T, C> {
fn default() -> Self {
StripedScores::empty()
}
}
impl<C: StrictlyPositive> Index<usize> for StripedScores<C> {
type Output = f32;
impl<T: MatrixElement, C: StrictlyPositive> Index<usize> for StripedScores<T, C> {
type Output = T;
#[inline]
fn index(&self, index: usize) -> &f32 {
fn index(&self, index: usize) -> &T {
let col = index / self.data.rows();
let row = index % self.data.rows();
&self.data[row][col]
}
}
impl<C: StrictlyPositive> From<StripedScores<C>> for Vec<f32> {
fn from(scores: StripedScores<C>) -> Self {
impl<T: MatrixElement, C: StrictlyPositive> From<StripedScores<T, C>> for Vec<T> {
fn from(scores: StripedScores<T, C>) -> Self {
scores.iter().cloned().collect()
}
}
// --- Iter --------------------------------------------------------------------
pub struct Iter<'a, C: StrictlyPositive> {
scores: &'a StripedScores<C>,
pub struct Iter<'a, T: MatrixElement, C: StrictlyPositive> {
scores: &'a StripedScores<T, C>,
indices: Range<usize>,
}
impl<'a, C: StrictlyPositive> Iter<'a, C> {
fn new(scores: &'a StripedScores<C>) -> Self {
impl<'a, T: MatrixElement, C: StrictlyPositive> Iter<'a, T, C> {
fn new(scores: &'a StripedScores<T, C>) -> Self {
// Compute the last index
let end = scores
.max_index
......@@ -242,29 +243,29 @@ impl<'a, C: StrictlyPositive> Iter<'a, C> {
Self { scores, indices }
}
fn get(&self, i: usize) -> &'a f32 {
fn get(&self, i: usize) -> &'a T {
let col = i / self.scores.data.rows();
let row = i % self.scores.data.rows();
&self.scores.data[row][col]
}
}
impl<'a, C: StrictlyPositive> Iterator for Iter<'a, C> {
type Item = &'a f32;
impl<'a, T: MatrixElement, C: StrictlyPositive> Iterator for Iter<'a, T, C> {
type Item = &'a T;
fn next(&mut self) -> Option<Self::Item> {
self.indices.next().map(|i| self.get(i))
}
}
impl<'a, C: StrictlyPositive> ExactSizeIterator for Iter<'a, C> {
impl<'a, T: MatrixElement, C: StrictlyPositive> ExactSizeIterator for Iter<'a, T, C> {
fn len(&self) -> usize {
self.indices.len()
}
}
impl<'a, C: StrictlyPositive> FusedIterator for Iter<'a, C> {}
impl<'a, T: MatrixElement, C: StrictlyPositive> FusedIterator for Iter<'a, T, C> {}
impl<'a, C: StrictlyPositive> DoubleEndedIterator for Iter<'a, C> {
impl<'a, T: MatrixElement, C: StrictlyPositive> DoubleEndedIterator for Iter<'a, T, C> {
fn next_back(&mut self) -> Option<Self::Item> {
self.indices.next_back().map(|i| self.get(i))
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment