Skip to content
Snippets Groups Projects
Commit 085a85c2 authored by Martin Larralde's avatar Martin Larralde
Browse files

Add new trait to search positions above threshold in `StripedScores`

parent a4c07b1f
No related branches found
No related tags found
No related merge requests found
......@@ -81,14 +81,11 @@ pub trait BestPosition<C: StrictlyPositive> {
return None;
}
let data = scores.matrix();
let mut best_pos = 0;
let mut best_score = data[0][0];
for i in 0..scores.len() {
let col = i / data.rows();
let row = i % data.rows();
if data[row][col] > best_score {
best_score = data[row][col];
let mut best_score = scores[0];
for i in 1..scores.len() {
if scores[i] > best_score {
best_score = scores[i];
best_pos = i;
}
}
......@@ -97,6 +94,24 @@ pub trait BestPosition<C: StrictlyPositive> {
}
}
/// Generic trait for finding positions above a score threshold in a striped score matrix.
pub trait Threshold<C: StrictlyPositive> {
/// Return the indices of positions with score equal to or greater than the threshold.
///
/// # Note
///
/// The indices may or may not be sorted, depending on the implementation.
fn threshold(&self, scores: &StripedScores<C>, threshold: f32) -> Vec<usize> {
let mut positions = Vec::new();
for i in 0..scores.len() {
if scores[i] >= threshold {
positions.push(i);
}
}
positions
}
}
// --- Pipeline ----------------------------------------------------------------
/// Wrapper implementing score computation for different platforms.
......@@ -122,6 +137,8 @@ impl<A: Alphabet, C: StrictlyPositive> Score<A, C> for Pipeline<A, Generic> {}
impl<A: Alphabet, C: StrictlyPositive> BestPosition<C> for Pipeline<A, Generic> {}
impl<A: Alphabet, C: StrictlyPositive> Threshold<C> for Pipeline<A, Generic> {}
// --- SSE2 pipeline -----------------------------------------------------------
impl<A: Alphabet> Pipeline<A, Sse2> {
......@@ -166,6 +183,8 @@ where
}
}
impl<A: Alphabet, C: StrictlyPositive> Threshold<C> for Pipeline<A, Sse2> {}
// --- AVX2 pipeline -----------------------------------------------------------
impl<A: Alphabet> Pipeline<A, Avx2> {
......@@ -198,6 +217,7 @@ impl<A: Alphabet> BestPosition<<Avx2 as Backend>::LANES> for Pipeline<A, Avx2> {
Avx2::best_position(scores)
}
}
impl<A: Alphabet> Threshold<<Avx2 as Backend>::LANES> for Pipeline<A, Avx2> {}
// --- NEON pipeline -----------------------------------------------------------
......@@ -240,3 +260,12 @@ where
<C as Div<U16>>::Output: Unsigned,
{
}
impl<A, C> Threshold<C> for Pipeline<A, Neon>
where
A: Alphabet,
C: StrictlyPositive + Rem<U16> + Div<U16>,
<C as Rem<U16>>::Output: Zero,
<C as Div<U16>>::Output: Unsigned,
{
}
......@@ -7,10 +7,7 @@ use std::ops::Range;
use typenum::marker_traits::NonZero;
use typenum::marker_traits::Unsigned;
use crate::abc::Alphabet;
use crate::dense::DenseMatrix;
use crate::pwm::ScoringMatrix;
use crate::seq::StripedSequence;
/// Striped matrix storing scores for an equally striped sequence.
#[derive(Clone, Debug)]
......
......@@ -9,6 +9,7 @@ use lightmotif::num::U32;
use lightmotif::pli::BestPosition;
use lightmotif::pli::Pipeline;
use lightmotif::pli::Score;
use lightmotif::pli::Threshold;
use lightmotif::pwm::CountMatrix;
use lightmotif::seq::EncodedSequence;
use lightmotif::seq::StripedSequence;
......@@ -77,6 +78,25 @@ fn test_best_position<C: StrictlyPositive, P: Score<Dna, C> + BestPosition<C>>(p
assert_eq!(pli.best_position(&result), Some(18));
}
fn test_threshold<C: StrictlyPositive, P: Score<Dna, C> + Threshold<C>>(pli: &P) {
let mut striped = StripedSequence::<Dna, C>::encode(SEQUENCE).unwrap();
let cm = CountMatrix::<Dna>::from_sequences(
PATTERNS.iter().map(|x| EncodedSequence::encode(x).unwrap()),
)
.unwrap();
let pbm = cm.to_freq(0.1);
let pwm = pbm.to_weight(None);
let pssm = pwm.into();
striped.configure(&pssm);
let result = pli.score(&striped, &pssm);
let mut positions = pli.threshold(&result, -10.0);
positions.sort_unstable();
assert_eq!(positions, vec![18, 27, 32]);
}
#[test]
fn test_score_generic() {
let pli = Pipeline::generic();
......@@ -91,6 +111,12 @@ fn test_best_position_generic() {
test_best_position::<U1, _>(&pli);
}
#[test]
fn test_threshold_generic() {
let pli = Pipeline::generic();
test_threshold::<U32, _>(&pli);
}
#[cfg(target_feature = "sse2")]
#[test]
fn test_score_sse2() {
......@@ -105,6 +131,13 @@ fn test_best_position_sse2() {
test_best_position::<U16, _>(&pli);
}
#[cfg(target_feature = "sse2")]
#[test]
fn test_threshold_sse2() {
let pli = Pipeline::sse2().unwrap();
test_threshold::<U16, _>(&pli);
}
#[cfg(target_feature = "sse2")]
#[test]
fn test_score_sse2_32() {
......@@ -119,6 +152,13 @@ fn test_best_position_sse2_32() {
test_best_position::<U32, _>(&pli);
}
#[cfg(target_feature = "sse2")]
#[test]
fn test_threshold_sse2_32() {
let pli = Pipeline::sse2().unwrap();
test_threshold::<U32, _>(&pli);
}
#[cfg(target_feature = "avx2")]
#[test]
fn test_score_avx2() {
......@@ -133,6 +173,13 @@ fn test_best_position_avx2() {
test_best_position(&pli);
}
#[cfg(target_feature = "avx2")]
#[test]
fn test_threshold_avx2() {
let pli = Pipeline::avx2().unwrap();
test_threshold::<U32, _>(&pli);
}
#[cfg(target_feature = "neon")]
#[test]
fn test_score_neon() {
......@@ -146,3 +193,10 @@ fn test_best_position_neon() {
let pli = Pipeline::neon().unwrap();
test_best_position::<U16, _>(&pli);
}
#[cfg(target_feature = "neon")]
#[test]
fn test_threshold_neon() {
let pli = Pipeline::neon().unwrap();
test_threshold::<U16, _>(&pli);
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment