Skip to content
Snippets Groups Projects
Commit b4408fc4 authored by Martin Larralde's avatar Martin Larralde
Browse files

Add SSSE3 implementation of the maximum score search algorithm

parent 7f70de3b
No related branches found
No related tags found
No related merge requests found
......@@ -48,9 +48,9 @@ fn bench_generic(bencher: &mut test::Bencher) {
});
}
#[cfg(target_feature = "avx2")]
#[cfg(target_feature = "ssse3")]
#[bench]
fn bench_sse2(bencher: &mut test::Bencher) {
fn bench_ssse3(bencher: &mut test::Bencher) {
let seq = &SEQUENCE[..10000];
let encoded = EncodedSequence::<Dna>::from_str(seq).unwrap();
let mut striped = encoded.to_striped();
......@@ -69,7 +69,7 @@ fn bench_sse2(bencher: &mut test::Bencher) {
bencher.bytes = seq.len() as u64;
bencher.iter(|| {
Pipeline::<_, __m128>::score_into(&striped, &pssm, &mut scores);
test::black_box(Pipeline::<_, __m128>::best_position(&scores));
test::black_box(Pipeline::<_, __m128>::best_position(&scores).unwrap());
});
}
......
......@@ -175,10 +175,79 @@ unsafe fn score_ssse3(
}
// record the score for the current position
let row = &mut scores.data[i];
_mm_store_ps(row[0..].as_mut_ptr(), s1);
_mm_store_ps(row[4..].as_mut_ptr(), s2);
_mm_store_ps(row[8..].as_mut_ptr(), s3);
_mm_store_ps(row[12..].as_mut_ptr(), s4);
_mm_storeu_ps(row[0..].as_mut_ptr(), s1);
_mm_storeu_ps(row[4..].as_mut_ptr(), s2);
_mm_storeu_ps(row[8..].as_mut_ptr(), s3);
_mm_storeu_ps(row[12..].as_mut_ptr(), s4);
}
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "ssse3")]
unsafe fn best_position_ssse3(
scores: &StripedScores<{ std::mem::size_of::<__m128i>() }>,
) -> Option<usize> {
if scores.length == 0 {
None
} else {
unsafe {
// the row index for the best score in each column
// (these are 32-bit integers but for use with `_mm256_blendv_ps`
// they get stored in 32-bit float vectors).
let mut p1 = _mm_setzero_ps();
let mut p2 = _mm_setzero_ps();
let mut p3 = _mm_setzero_ps();
let mut p4 = _mm_setzero_ps();
// store the best scores for each column
let mut s1 = _mm_loadu_ps(scores.data[0][0x00..].as_ptr());
let mut s2 = _mm_loadu_ps(scores.data[0][0x04..].as_ptr());
let mut s3 = _mm_loadu_ps(scores.data[0][0x08..].as_ptr());
let mut s4 = _mm_loadu_ps(scores.data[0][0x0c..].as_ptr());
// process all rows iteratively
for (i, row) in scores.data.iter().enumerate() {
// record the current row index
let index = _mm_castsi128_ps(_mm_set1_epi32(i as i32));
// load scores for the current row
let r1 = _mm_loadu_ps(row[0x00..].as_ptr());
let r2 = _mm_loadu_ps(row[0x04..].as_ptr());
let r3 = _mm_loadu_ps(row[0x08..].as_ptr());
let r4 = _mm_loadu_ps(row[0x0c..].as_ptr());
// compare scores to local maximums
let c1 = _mm_cmplt_ps(s1, r1);
let c2 = _mm_cmplt_ps(s2, r2);
let c3 = _mm_cmplt_ps(s3, r3);
let c4 = _mm_cmplt_ps(s4, r4);
// NOTE: code below could use `_mm_blendv_ps` instead,
// but this instruction is only available on SSE4.1
// while the rest of the code is actually using at
// most SSSE3 instructions.
// replace indices of new local maximums
p1 = _mm_or_ps(_mm_andnot_ps(p1, c1), _mm_and_ps(index, c1));
p2 = _mm_or_ps(_mm_andnot_ps(p2, c2), _mm_and_ps(index, c2));
p3 = _mm_or_ps(_mm_andnot_ps(p3, c3), _mm_and_ps(index, c3));
p4 = _mm_or_ps(_mm_andnot_ps(p4, c4), _mm_and_ps(index, c4));
// replace values of new local maximums
s1 = _mm_or_ps(_mm_andnot_ps(s1, c1), _mm_and_ps(r1, c1));
s2 = _mm_or_ps(_mm_andnot_ps(s2, c2), _mm_and_ps(r2, c2));
s3 = _mm_or_ps(_mm_andnot_ps(s3, c3), _mm_and_ps(r3, c3));
s4 = _mm_or_ps(_mm_andnot_ps(s4, c4), _mm_and_ps(r4, c4));
}
// find the global maximum across all columns
let mut x: [u32; 16] = [0; 16];
_mm_storeu_si128(x[0x00..].as_mut_ptr() as *mut _, _mm_castps_si128(p1));
_mm_storeu_si128(x[0x04..].as_mut_ptr() as *mut _, _mm_castps_si128(p2));
_mm_storeu_si128(x[0x08..].as_mut_ptr() as *mut _, _mm_castps_si128(p3));
_mm_storeu_si128(x[0x0c..].as_mut_ptr() as *mut _, _mm_castps_si128(p4));
let mut best_pos = 0;
let mut best_score = -f32::INFINITY;
for (col, &row) in x.iter().enumerate() {
if scores.data[row as usize][col] > best_score {
best_score = scores.data[row as usize][col];
best_pos = col * scores.data.rows() + row as usize;
}
}
Some(best_pos)
}
}
}
......@@ -211,6 +280,10 @@ impl Score<Dna, { Dna::K }, __m128, { std::mem::size_of::<__m128i>() }> for Pipe
score_ssse3(seq, pssm, scores);
}
}
fn best_position(scores: &StripedScores<{ std::mem::size_of::<__m128i>() }>) -> Option<usize> {
unsafe { best_position_ssse3(scores) }
}
}
// --- AVX2 --------------------------------------------------------------------
......@@ -338,19 +411,19 @@ unsafe fn best_position_avx2(
let mut p3 = _mm256_setzero_ps();
let mut p4 = _mm256_setzero_ps();
// store the best scores for each column
let mut s1 = _mm256_loadu_ps(scores.data[0][0x00..].as_ptr());
let mut s2 = _mm256_loadu_ps(scores.data[0][0x04..].as_ptr());
let mut s3 = _mm256_loadu_ps(scores.data[0][0x08..].as_ptr());
let mut s4 = _mm256_loadu_ps(scores.data[0][0x0c..].as_ptr());
let mut s1 = _mm256_load_ps(scores.data[0][0x00..].as_ptr());
let mut s2 = _mm256_load_ps(scores.data[0][0x08..].as_ptr());
let mut s3 = _mm256_load_ps(scores.data[0][0x10..].as_ptr());
let mut s4 = _mm256_load_ps(scores.data[0][0x18..].as_ptr());
// process all rows iteratively
for (i, row) in scores.data.iter().enumerate() {
// record the current row index
let index = _mm256_castsi256_ps(_mm256_set1_epi32(i as i32));
// load scores for the current row
let r1 = _mm256_loadu_ps(row[0x00..].as_ptr());
let r2 = _mm256_loadu_ps(row[0x04..].as_ptr());
let r3 = _mm256_loadu_ps(row[0x08..].as_ptr());
let r4 = _mm256_loadu_ps(row[0x0c..].as_ptr());
let r1 = _mm256_load_ps(row[0x00..].as_ptr());
let r2 = _mm256_load_ps(row[0x08..].as_ptr());
let r3 = _mm256_load_ps(row[0x10..].as_ptr());
let r4 = _mm256_load_ps(row[0x18..].as_ptr());
// compare scores to local maximums
let c1 = _mm256_cmp_ps(s1, r1, _CMP_LT_OS);
let c2 = _mm256_cmp_ps(s2, r2, _CMP_LT_OS);
......@@ -370,9 +443,9 @@ unsafe fn best_position_avx2(
// find the global maximum across all columns
let mut x: [u32; 32] = [0; 32];
_mm256_storeu_si256(x[0x00..].as_mut_ptr() as *mut _, _mm256_castps_si256(p1));
_mm256_storeu_si256(x[0x04..].as_mut_ptr() as *mut _, _mm256_castps_si256(p2));
_mm256_storeu_si256(x[0x08..].as_mut_ptr() as *mut _, _mm256_castps_si256(p3));
_mm256_storeu_si256(x[0x0c..].as_mut_ptr() as *mut _, _mm256_castps_si256(p4));
_mm256_storeu_si256(x[0x08..].as_mut_ptr() as *mut _, _mm256_castps_si256(p2));
_mm256_storeu_si256(x[0x10..].as_mut_ptr() as *mut _, _mm256_castps_si256(p3));
_mm256_storeu_si256(x[0x18..].as_mut_ptr() as *mut _, _mm256_castps_si256(p4));
let mut best_pos = 0;
let mut best_score = -f32::INFINITY;
for (col, &row) in x.iter().enumerate() {
......
......@@ -76,7 +76,7 @@ fn test_best_position_generic() {
.map(|x| EncodedSequence::from_str(x).unwrap()),
)
.unwrap();
let pbm = cm.to_freq([0.1, 0.1, 0.1, 0.1, 0.0]);
let pbm = cm.to_freq(0.1);
let pwm = pbm.to_weight(None);
let pssm = pwm.into();
......@@ -124,6 +124,27 @@ fn test_score_ssse3() {
}
}
#[cfg(target_feature = "ssse3")]
#[test]
fn test_best_position_ssse3() {
let encoded = EncodedSequence::<Dna>::from_str(SEQUENCE).unwrap();
let mut striped = encoded.to_striped();
let cm = CountMatrix::<Dna, { Dna::K }>::from_sequences(
PATTERNS
.iter()
.map(|x| EncodedSequence::from_str(x).unwrap()),
)
.unwrap();
let pbm = cm.to_freq(0.1);
let pwm = pbm.to_weight(None);
let pssm = pwm.into();
striped.configure(&pssm);
let result = Pipeline::<Dna, __m128>::score(&striped, &pssm);
assert_eq!(Pipeline::<Dna, __m128>::best_position(&result), Some(18));
}
#[cfg(target_feature = "avx2")]
#[test]
fn test_score_avx2() {
......@@ -136,7 +157,7 @@ fn test_score_avx2() {
.map(|x| EncodedSequence::from_str(x).unwrap()),
)
.unwrap();
let pbm = cm.to_freq([0.1, 0.1, 0.1, 0.1, 0.0]);
let pbm = cm.to_freq(0.1);
let pwm = pbm.to_weight(None);
let pssm = pwm.into();
......@@ -144,10 +165,6 @@ fn test_score_avx2() {
let result = Pipeline::<_, __m256>::score(&striped, &pssm);
let scores = result.to_vec();
// for i in 0..result.data.rows() {
// println!("{:?}", &result.data[i]);
// }
assert_eq!(scores.len(), EXPECTED.len());
for i in 0..EXPECTED.len() {
assert!(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment