Skip to content
Snippets Groups Projects
Commit 7f70de3b authored by Martin Larralde's avatar Martin Larralde
Browse files

Add method to extract the best position from a `StripedScores` matrix with AVX2

parent a43aefce
No related branches found
No related tags found
No related merge requests found
...@@ -42,7 +42,9 @@ fn bench_generic(bencher: &mut test::Bencher) { ...@@ -42,7 +42,9 @@ fn bench_generic(bencher: &mut test::Bencher) {
bencher.bytes = seq.len() as u64; bencher.bytes = seq.len() as u64;
bencher.iter(|| { bencher.iter(|| {
Pipeline::<_, f32>::score_into(&striped, &pssm, &mut scores); Pipeline::<_, f32>::score_into(&striped, &pssm, &mut scores);
test::black_box(scores.argmax()); test::black_box(
<Pipeline<_, f32> as Score<Dna, { Dna::K }, f32, 32>>::best_position(&scores),
);
}); });
} }
...@@ -67,7 +69,7 @@ fn bench_sse2(bencher: &mut test::Bencher) { ...@@ -67,7 +69,7 @@ fn bench_sse2(bencher: &mut test::Bencher) {
bencher.bytes = seq.len() as u64; bencher.bytes = seq.len() as u64;
bencher.iter(|| { bencher.iter(|| {
Pipeline::<_, __m128>::score_into(&striped, &pssm, &mut scores); Pipeline::<_, __m128>::score_into(&striped, &pssm, &mut scores);
test::black_box(scores.argmax()); test::black_box(Pipeline::<_, __m128>::best_position(&scores));
}); });
} }
...@@ -92,7 +94,7 @@ fn bench_avx2(bencher: &mut test::Bencher) { ...@@ -92,7 +94,7 @@ fn bench_avx2(bencher: &mut test::Bencher) {
bencher.bytes = seq.len() as u64; bencher.bytes = seq.len() as u64;
bencher.iter(|| { bencher.iter(|| {
Pipeline::<_, __m256>::score_into(&striped, &pssm, &mut scores); Pipeline::<_, __m256>::score_into(&striped, &pssm, &mut scores);
test::black_box(scores.argmax()); test::black_box(Pipeline::<_, __m256>::best_position(&scores).unwrap());
}); });
} }
......
...@@ -30,11 +30,13 @@ mod seal { ...@@ -30,11 +30,13 @@ mod seal {
/// Generic trait for computing sequence scores with a PSSM. /// Generic trait for computing sequence scores with a PSSM.
pub trait Score<A: Alphabet, const K: usize, V: Vector, const C: usize> { pub trait Score<A: Alphabet, const K: usize, V: Vector, const C: usize> {
/// Compute the PSSM scores into the given buffer.
fn score_into<S, M>(seq: S, pssm: M, scores: &mut StripedScores<C>) fn score_into<S, M>(seq: S, pssm: M, scores: &mut StripedScores<C>)
where where
S: AsRef<StripedSequence<A, C>>, S: AsRef<StripedSequence<A, C>>,
M: AsRef<ScoringMatrix<A, K>>; M: AsRef<ScoringMatrix<A, K>>;
/// Compute the PSSM scores for every sequence positions.
fn score<S, M>(seq: S, pssm: M) -> StripedScores<C> fn score<S, M>(seq: S, pssm: M) -> StripedScores<C>
where where
S: AsRef<StripedSequence<A, C>>, S: AsRef<StripedSequence<A, C>>,
...@@ -44,6 +46,26 @@ pub trait Score<A: Alphabet, const K: usize, V: Vector, const C: usize> { ...@@ -44,6 +46,26 @@ pub trait Score<A: Alphabet, const K: usize, V: Vector, const C: usize> {
Self::score_into(seq, pssm, &mut scores); Self::score_into(seq, pssm, &mut scores);
scores scores
} }
/// Find the sequence position with the highest score.
fn best_position(scores: &StripedScores<C>) -> Option<usize> {
if scores.length == 0 {
return None;
}
let mut best_pos = 0;
let mut best_score = scores.data[0][0];
for i in 0..scores.length {
let col = i / scores.data.rows();
let row = i % scores.data.rows();
if scores.data[row][col] > best_score {
best_score = scores.data[row][col];
best_pos = i;
}
}
Some(best_pos)
}
} }
// --- Pipeline ---------------------------------------------------------------- // --- Pipeline ----------------------------------------------------------------
...@@ -299,6 +321,71 @@ unsafe fn score_avx2( ...@@ -299,6 +321,71 @@ unsafe fn score_avx2(
} }
} }
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn best_position_avx2(
scores: &StripedScores<{ std::mem::size_of::<__m256i>() }>,
) -> Option<usize> {
if scores.length == 0 {
None
} else {
unsafe {
// the row index for the best score in each column
// (these are 32-bit integers but for use with `_mm256_blendv_ps`
// they get stored in 32-bit float vectors).
let mut p1 = _mm256_setzero_ps();
let mut p2 = _mm256_setzero_ps();
let mut p3 = _mm256_setzero_ps();
let mut p4 = _mm256_setzero_ps();
// store the best scores for each column
let mut s1 = _mm256_loadu_ps(scores.data[0][0x00..].as_ptr());
let mut s2 = _mm256_loadu_ps(scores.data[0][0x04..].as_ptr());
let mut s3 = _mm256_loadu_ps(scores.data[0][0x08..].as_ptr());
let mut s4 = _mm256_loadu_ps(scores.data[0][0x0c..].as_ptr());
// process all rows iteratively
for (i, row) in scores.data.iter().enumerate() {
// record the current row index
let index = _mm256_castsi256_ps(_mm256_set1_epi32(i as i32));
// load scores for the current row
let r1 = _mm256_loadu_ps(row[0x00..].as_ptr());
let r2 = _mm256_loadu_ps(row[0x04..].as_ptr());
let r3 = _mm256_loadu_ps(row[0x08..].as_ptr());
let r4 = _mm256_loadu_ps(row[0x0c..].as_ptr());
// compare scores to local maximums
let c1 = _mm256_cmp_ps(s1, r1, _CMP_LT_OS);
let c2 = _mm256_cmp_ps(s2, r2, _CMP_LT_OS);
let c3 = _mm256_cmp_ps(s3, r3, _CMP_LT_OS);
let c4 = _mm256_cmp_ps(s4, r4, _CMP_LT_OS);
// replace indices of new local maximums
p1 = _mm256_blendv_ps(p1, index, c1);
p2 = _mm256_blendv_ps(p2, index, c2);
p3 = _mm256_blendv_ps(p3, index, c3);
p4 = _mm256_blendv_ps(p4, index, c4);
// replace values of new local maximums
s1 = _mm256_blendv_ps(s1, r1, c1);
s2 = _mm256_blendv_ps(s2, r2, c2);
s3 = _mm256_blendv_ps(s3, r3, c3);
s4 = _mm256_blendv_ps(s4, r4, c4);
}
// find the global maximum across all columns
let mut x: [u32; 32] = [0; 32];
_mm256_storeu_si256(x[0x00..].as_mut_ptr() as *mut _, _mm256_castps_si256(p1));
_mm256_storeu_si256(x[0x04..].as_mut_ptr() as *mut _, _mm256_castps_si256(p2));
_mm256_storeu_si256(x[0x08..].as_mut_ptr() as *mut _, _mm256_castps_si256(p3));
_mm256_storeu_si256(x[0x0c..].as_mut_ptr() as *mut _, _mm256_castps_si256(p4));
let mut best_pos = 0;
let mut best_score = -f32::INFINITY;
for (col, &row) in x.iter().enumerate() {
if scores.data[row as usize][col] > best_score {
best_score = scores.data[row as usize][col];
best_pos = col * scores.data.rows() + row as usize;
}
}
Some(best_pos)
}
}
}
/// Intel 256-bit vector implementation, for 32 elements column width. /// Intel 256-bit vector implementation, for 32 elements column width.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
impl Score<Dna, { Dna::K }, __m256, { std::mem::size_of::<__m256i>() }> for Pipeline<Dna, __m256> { impl Score<Dna, { Dna::K }, __m256, { std::mem::size_of::<__m256i>() }> for Pipeline<Dna, __m256> {
...@@ -329,6 +416,10 @@ impl Score<Dna, { Dna::K }, __m256, { std::mem::size_of::<__m256i>() }> for Pipe ...@@ -329,6 +416,10 @@ impl Score<Dna, { Dna::K }, __m256, { std::mem::size_of::<__m256i>() }> for Pipe
score_avx2(seq, pssm, scores); score_avx2(seq, pssm, scores);
} }
} }
fn best_position(scores: &StripedScores<{ std::mem::size_of::<__m256i>() }>) -> Option<usize> {
unsafe { best_position_avx2(scores) }
}
} }
// --- StripedScores ----------------------------------------------------------- // --- StripedScores -----------------------------------------------------------
...@@ -403,25 +494,6 @@ impl<const C: usize> StripedScores<C> { ...@@ -403,25 +494,6 @@ impl<const C: usize> StripedScores<C> {
} }
vec vec
} }
/// Get the index of the highest scoring position, if any.
pub fn argmax(&self) -> Option<usize> {
if self.len() > 0 {
let mut best_pos = 0;
let mut best_score = self.data[0][0];
for i in 0..self.length {
let col = i / self.data.rows();
let row = i % self.data.rows();
if self.data[row][col] > best_score {
best_pos = i;
best_score = self.data[row][col];
}
}
Some(best_pos)
} else {
None
}
}
} }
impl<const C: usize> AsRef<DenseMatrix<f32, C>> for StripedScores<C> { impl<const C: usize> AsRef<DenseMatrix<f32, C>> for StripedScores<C> {
......
...@@ -65,6 +65,29 @@ fn test_score_generic() { ...@@ -65,6 +65,29 @@ fn test_score_generic() {
} }
} }
#[test]
fn test_best_position_generic() {
let encoded = EncodedSequence::<Dna>::from_str(SEQUENCE).unwrap();
let mut striped = encoded.to_striped::<2>();
let cm = CountMatrix::<Dna, { Dna::K }>::from_sequences(
PATTERNS
.iter()
.map(|x| EncodedSequence::from_str(x).unwrap()),
)
.unwrap();
let pbm = cm.to_freq([0.1, 0.1, 0.1, 0.1, 0.0]);
let pwm = pbm.to_weight(None);
let pssm = pwm.into();
striped.configure(&pssm);
let result = Pipeline::<Dna, f32>::score(&striped, &pssm);
assert_eq!(
<Pipeline::<Dna, f32> as Score<Dna, { Dna::K }, _, 2>>::best_position(&result),
Some(18)
);
}
#[cfg(target_feature = "ssse3")] #[cfg(target_feature = "ssse3")]
#[test] #[test]
fn test_score_ssse3() { fn test_score_ssse3() {
...@@ -136,3 +159,24 @@ fn test_score_avx2() { ...@@ -136,3 +159,24 @@ fn test_score_avx2() {
); );
} }
} }
#[cfg(target_feature = "avx2")]
#[test]
fn test_best_position_avx2() {
let encoded = EncodedSequence::<Dna>::from_str(SEQUENCE).unwrap();
let mut striped = encoded.to_striped();
let cm = CountMatrix::<Dna, { Dna::K }>::from_sequences(
PATTERNS
.iter()
.map(|x| EncodedSequence::from_str(x).unwrap()),
)
.unwrap();
let pbm = cm.to_freq([0.1, 0.1, 0.1, 0.1, 0.0]);
let pwm = pbm.to_weight(None);
let pssm = pwm.into();
striped.configure(&pssm);
let result = Pipeline::<Dna, __m256>::score(&striped, &pssm);
assert_eq!(Pipeline::<Dna, __m256>::best_position(&result), Some(18));
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment