Skip to content
Snippets Groups Projects
Commit 5f91f3f0 authored by Martin Larralde
Browse files

Implement `Score<u8, Dna>` for the AVX2 platform

parent bbde5508
No related branches found
No related tags found
No related merge requests found
......@@ -184,7 +184,7 @@ impl<A: Alphabet> Maximum<f32, <Dispatch as Backend>::LANES> for Pipeline<A, Dis
) -> Option<MatrixCoordinates> {
match self.backend {
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
Dispatch::Avx2 => Avx2::argmax(scores),
Dispatch::Avx2 => Avx2::argmax_f32(scores),
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
Dispatch::Sse2 => Sse2::argmax(scores),
_ => <Generic as Maximum<f32, <Dispatch as Backend>::LANES>>::argmax(&Generic, scores),
......@@ -198,8 +198,8 @@ impl<A: Alphabet> Maximum<u8, <Dispatch as Backend>::LANES> for Pipeline<A, Disp
scores: &StripedScores<u8, <Dispatch as Backend>::LANES>,
) -> Option<MatrixCoordinates> {
match self.backend {
// #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
// Dispatch::Avx2 => Avx2::argmax(scores),
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
Dispatch::Avx2 => Avx2::argmax_u8(scores),
// #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
// Dispatch::Sse2 => Sse2::argmax(scores),
_ => <Generic as Maximum<u8, <Dispatch as Backend>::LANES>>::argmax(&Generic, scores),
......@@ -208,3 +208,5 @@ impl<A: Alphabet> Maximum<u8, <Dispatch as Backend>::LANES> for Pipeline<A, Disp
}
// `Threshold` for `f32` scores on the dispatching pipeline: the impl body is
// empty, so every item comes from the trait's provided defaults.
impl<A: Alphabet> Threshold<f32, <Dispatch as Backend>::LANES> for Pipeline<A, Dispatch> {}
// Same for `u8` scores: no item is overridden here.
impl<A: Alphabet> Threshold<u8, <Dispatch as Backend>::LANES> for Pipeline<A, Dispatch> {}
......@@ -472,10 +472,18 @@ impl<A: Alphabet> Maximum<f32, <Avx2 as Backend>::LANES> for Pipeline<A, Avx2> {
&self,
scores: &StripedScores<f32, <Avx2 as Backend>::LANES>,
) -> Option<MatrixCoordinates> {
Avx2::argmax(scores)
Avx2::argmax_f32(scores)
}
}
/// `Maximum` for `u8` striped scores on a pipeline pinned to the AVX2 backend.
impl<A> Maximum<u8, <Avx2 as Backend>::LANES> for Pipeline<A, Avx2>
where
    A: Alphabet,
{
    /// Return the coordinates of the best score in `scores`, or `None` when
    /// the backend reports no position.
    ///
    /// The pipeline itself carries no state needed here; the call is
    /// forwarded as-is to `Avx2::argmax_u8`.
    fn argmax(
        &self,
        scores: &StripedScores<u8, <Avx2 as Backend>::LANES>,
    ) -> Option<MatrixCoordinates> {
        let best = Avx2::argmax_u8(scores);
        best
    }
}
impl<A: Alphabet> Maximum<u8, <Avx2 as Backend>::LANES> for Pipeline<A, Avx2> {}
// `Threshold` for `f32` scores on the AVX2 pipeline: empty impl body, so only
// the trait's provided defaults are used.
impl<A: Alphabet> Threshold<f32, <Avx2 as Backend>::LANES> for Pipeline<A, Avx2> {}
......
......@@ -316,7 +316,7 @@ pub unsafe fn score_u8_avx2_shuffle<A>(
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn argmax_avx2(
unsafe fn argmax_f32_avx2(
scores: &StripedScores<f32, <Avx2 as Backend>::LANES>,
) -> Option<MatrixCoordinates> {
if scores.max_index() > u32::MAX as usize {
......@@ -400,6 +400,78 @@ unsafe fn argmax_avx2(
}
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn argmax_u8_avx2(
scores: &StripedScores<u8, <Avx2 as Backend>::LANES>,
) -> Option<MatrixCoordinates> {
if scores.matrix().rows() > u16::MAX as usize + 1 {
panic!(
"This implementation only supports matrices with at most {} rows, found a sequence with {} rows. Contact the developers at https://github.com/althonos/lightmotif.",
u16::MAX, scores.matrix().rows()
);
} else if scores.is_empty() {
None
} else {
let data = scores.matrix();
unsafe {
let mut dataptr = data[0].as_ptr();
// the row index for the best score in each column
// (these are 32-bit integers but for use with `_mm256_blendv_ps`
// they get stored in 32-bit float vectors).
let mut p1 = _mm256_setzero_si256();
let mut p2 = _mm256_setzero_si256();
// store the best scores for each column
let mut s1 = _mm256_setzero_si256();
let mut s2 = _mm256_setzero_si256();
// process all rows iteratively
for i in 0..data.rows() {
// record the current row index
let index = _mm256_set1_epi16(i as i16);
// load scores for the current row
let r = _mm256_load_si256(dataptr as *const _);
// unpack scores into 16-bit vectors (we can't use 8-bit
// vectors directly because AVX2 doesn't support unsigned
// comparisons with 8-bit integers, so we need to translate
// them to signed comparisons in 16-bit space)
let r1 = _mm256_unpacklo_epi8(r, _mm256_setzero_si256());
let r2 = _mm256_unpackhi_epi8(r, _mm256_setzero_si256());
// compare scores to local maximums
let c1 = _mm256_or_si256(_mm256_cmpgt_epi16(r1, s1), _mm256_cmpeq_epi16(r1, s1));
let c2 = _mm256_or_si256(_mm256_cmpgt_epi16(r2, s2), _mm256_cmpeq_epi16(r2, s2));
// replace indices of new local maximums
p1 = _mm256_blendv_epi8(p1, index, c1);
p2 = _mm256_blendv_epi8(p2, index, c2);
// replace values of new local maximums
s1 = _mm256_blendv_epi8(s1, r1, c1);
s2 = _mm256_blendv_epi8(s2, r2, c2);
// advance to next row
dataptr = dataptr.add(data.stride());
}
// find the global maximum across all columns
let mut x: [u16; 32] = [0; 32];
_mm256_storeu_si256(x.as_mut_ptr() as *mut _, p1);
_mm256_storeu_si256(x[16..].as_mut_ptr() as *mut _, p2);
// println!("{:?}", x);
let mut best_pos = MatrixCoordinates::default();
let mut best_score = data[best_pos];
for (col, &row) in x.iter().enumerate() {
let pos = MatrixCoordinates::new(row as usize, col);
let score = data[pos];
if (score >= best_score) {
best_score = score;
best_pos = pos;
}
}
Some(best_pos)
}
}
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn stripe_avx2<A>(
......@@ -765,12 +837,24 @@ impl Avx2 {
}
#[allow(unused)]
pub fn argmax(
pub fn argmax_f32(
scores: &StripedScores<f32, <Avx2 as Backend>::LANES>,
) -> Option<MatrixCoordinates> {
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
unsafe {
argmax_avx2(scores)
argmax_f32_avx2(scores)
}
#[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
panic!("attempting to run AVX2 code on a non-x86 host")
}
#[allow(unused)]
pub fn argmax_u8(
scores: &StripedScores<u8, <Avx2 as Backend>::LANES>,
) -> Option<MatrixCoordinates> {
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
unsafe {
argmax_u8_avx2(scores)
}
#[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
panic!("attempting to run AVX2 code on a non-x86 host")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment