diff --git a/lightmotif/src/pli/dispatch.rs b/lightmotif/src/pli/dispatch.rs
index 28790888252185e9f4c9e08b98a508e583aedf6f..63514133819fd8908e8ca35934e8643c479734af 100644
--- a/lightmotif/src/pli/dispatch.rs
+++ b/lightmotif/src/pli/dispatch.rs
@@ -184,7 +184,7 @@ impl<A: Alphabet> Maximum<f32, <Dispatch as Backend>::LANES> for Pipeline<A, Dis
     ) -> Option<MatrixCoordinates> {
         match self.backend {
             #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
-            Dispatch::Avx2 => Avx2::argmax(scores),
+            Dispatch::Avx2 => Avx2::argmax_f32(scores),
             #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
             Dispatch::Sse2 => Sse2::argmax(scores),
             _ => <Generic as Maximum<f32, <Dispatch as Backend>::LANES>>::argmax(&Generic, scores),
@@ -198,8 +198,8 @@ impl<A: Alphabet> Maximum<u8, <Dispatch as Backend>::LANES> for Pipeline<A, Disp
         scores: &StripedScores<u8, <Dispatch as Backend>::LANES>,
     ) -> Option<MatrixCoordinates> {
         match self.backend {
-            // #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
-            // Dispatch::Avx2 => Avx2::argmax(scores),
+            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
+            Dispatch::Avx2 => Avx2::argmax_u8(scores),
             // #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
             // Dispatch::Sse2 => Sse2::argmax(scores),
             _ => <Generic as Maximum<u8, <Dispatch as Backend>::LANES>>::argmax(&Generic, scores),
@@ -208,3 +208,5 @@ impl<A: Alphabet> Maximum<u8, <Dispatch as Backend>::LANES> for Pipeline<A, Disp
 }
 
 impl<A: Alphabet> Threshold<f32, <Dispatch as Backend>::LANES> for Pipeline<A, Dispatch> {}
+
+impl<A: Alphabet> Threshold<u8, <Dispatch as Backend>::LANES> for Pipeline<A, Dispatch> {}
diff --git a/lightmotif/src/pli/mod.rs b/lightmotif/src/pli/mod.rs
index 9c3a510eaf4a5845975116c62f4050d5262f0f91..e53a32d8fafb33acc5f649f6db28d30da9059d2d 100644
--- a/lightmotif/src/pli/mod.rs
+++ b/lightmotif/src/pli/mod.rs
@@ -472,10 +472,18 @@ impl<A: Alphabet> Maximum<f32, <Avx2 as Backend>::LANES> for Pipeline<A, Avx2> {
         &self,
         scores: &StripedScores<f32, <Avx2 as Backend>::LANES>,
     ) -> Option<MatrixCoordinates> {
-        Avx2::argmax(scores)
+        Avx2::argmax_f32(scores)
+    }
+}
+
+impl<A: Alphabet> Maximum<u8, <Avx2 as Backend>::LANES> for Pipeline<A, Avx2> {
+    fn argmax(
+        &self,
+        scores: &StripedScores<u8, <Avx2 as Backend>::LANES>,
+    ) -> Option<MatrixCoordinates> {
+        Avx2::argmax_u8(scores)
     }
 }
-impl<A: Alphabet> Maximum<u8, <Avx2 as Backend>::LANES> for Pipeline<A, Avx2> {}
 
 impl<A: Alphabet> Threshold<f32, <Avx2 as Backend>::LANES> for Pipeline<A, Avx2> {}
 
diff --git a/lightmotif/src/pli/platform/avx2.rs b/lightmotif/src/pli/platform/avx2.rs
index 7a1dc3a328515cafd82adb03a4d3b776b5fb1f2c..c2bce4bc16d85822ab6cd8ec927d9a4b3fe86d11 100644
--- a/lightmotif/src/pli/platform/avx2.rs
+++ b/lightmotif/src/pli/platform/avx2.rs
@@ -316,7 +316,7 @@ pub unsafe fn score_u8_avx2_shuffle<A>(
 
 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
 #[target_feature(enable = "avx2")]
-unsafe fn argmax_avx2(
+unsafe fn argmax_f32_avx2(
     scores: &StripedScores<f32, <Avx2 as Backend>::LANES>,
 ) -> Option<MatrixCoordinates> {
     if scores.max_index() > u32::MAX as usize {
@@ -400,6 +400,93 @@ unsafe fn argmax_avx2(
     }
 }
 
+/// Find the coordinates of the highest score in a striped `u8` score matrix.
+///
+/// Returns `None` when `scores` is empty.
+///
+/// # Panics
+/// Panics if the matrix has more than 65536 rows, because row indices are
+/// tracked in 16-bit lanes.
+///
+/// # Safety
+/// The CPU must support AVX2. Rows are read with an aligned 32-byte load, so
+/// each row of the matrix is assumed to be 32-byte aligned.
+#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
+#[target_feature(enable = "avx2")]
+unsafe fn argmax_u8_avx2(
+    scores: &StripedScores<u8, <Avx2 as Backend>::LANES>,
+) -> Option<MatrixCoordinates> {
+    if scores.matrix().rows() > u16::MAX as usize + 1 {
+        panic!(
+            "This implementation only supports matrices with at most {} rows, found a sequence with {} rows. Contact the developers at https://github.com/althonos/lightmotif.",
+            u16::MAX as usize + 1, scores.matrix().rows()
+        );
+    } else if scores.is_empty() {
+        None
+    } else {
+        let data = scores.matrix();
+        unsafe {
+            let mut dataptr = data[0].as_ptr();
+            // the row index for the best score in each column, as 16-bit
+            // integers (`p1`/`p2` follow the column order of the unpacked
+            // scores `r1`/`r2` below, not the order of the matrix columns)
+            let mut p1 = _mm256_setzero_si256();
+            let mut p2 = _mm256_setzero_si256();
+            // store the best scores for each column
+            let mut s1 = _mm256_setzero_si256();
+            let mut s2 = _mm256_setzero_si256();
+            // process all rows iteratively
+            for i in 0..data.rows() {
+                // record the current row index (at most 65535 here, so the cast
+                // to `i16` only reinterprets the 16 bits and loses nothing)
+                let index = _mm256_set1_epi16(i as i16);
+                // load scores for the current row
+                let r = _mm256_load_si256(dataptr as *const _);
+                // unpack scores into 16-bit vectors (we can't use 8-bit
+                // vectors directly because AVX2 doesn't support unsigned
+                // comparisons with 8-bit integers, so we need to translate
+                // them to signed comparisons in 16-bit space)
+                let r1 = _mm256_unpacklo_epi8(r, _mm256_setzero_si256());
+                let r2 = _mm256_unpackhi_epi8(r, _mm256_setzero_si256());
+                // compare scores to local maximums
+                let c1 = _mm256_or_si256(_mm256_cmpgt_epi16(r1, s1), _mm256_cmpeq_epi16(r1, s1));
+                let c2 = _mm256_or_si256(_mm256_cmpgt_epi16(r2, s2), _mm256_cmpeq_epi16(r2, s2));
+                // replace indices of new local maximums
+                p1 = _mm256_blendv_epi8(p1, index, c1);
+                p2 = _mm256_blendv_epi8(p2, index, c2);
+                // replace values of new local maximums
+                s1 = _mm256_blendv_epi8(s1, r1, c1);
+                s2 = _mm256_blendv_epi8(s2, r2, c2);
+                // advance to next row
+                dataptr = dataptr.add(data.stride());
+            }
+            // `_mm256_unpacklo_epi8` / `_mm256_unpackhi_epi8` interleave within
+            // each 128-bit lane, so `p1` holds columns 0-7 and 16-23 while `p2`
+            // holds columns 8-15 and 24-31: restore the column order first
+            let lo = _mm256_permute2x128_si256(p1, p2, 0x20);
+            let hi = _mm256_permute2x128_si256(p1, p2, 0x31);
+            // find the global maximum across all columns
+            let mut x: [u16; 32] = [0; 32];
+            _mm256_storeu_si256(x.as_mut_ptr() as *mut _, lo);
+            _mm256_storeu_si256(x[16..].as_mut_ptr() as *mut _, hi);
+
+            let mut best_pos = MatrixCoordinates::default();
+            let mut best_score = data[best_pos];
+
+            for (col, &row) in x.iter().enumerate() {
+                let pos = MatrixCoordinates::new(row as usize, col);
+                let score = data[pos];
+                if score >= best_score {
+                    best_score = score;
+                    best_pos = pos;
+                }
+            }
+
+            Some(best_pos)
+        }
+    }
+}
+
 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
 #[target_feature(enable = "avx2")]
 unsafe fn stripe_avx2<A>(
@@ -765,12 +852,27 @@ impl Avx2 {
     }
 
     #[allow(unused)]
-    pub fn argmax(
+    pub fn argmax_f32(
         scores: &StripedScores<f32, <Avx2 as Backend>::LANES>,
     ) -> Option<MatrixCoordinates> {
         #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
         unsafe {
-            argmax_avx2(scores)
+            argmax_f32_avx2(scores)
+        }
+        #[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
+        panic!("attempting to run AVX2 code on a non-x86 host")
+    }
+
+    /// Find the coordinates of the highest score in a `u8` score matrix.
+    ///
+    /// Panics when compiled for a target other than x86 or x86-64.
+    #[allow(unused)]
+    pub fn argmax_u8(
+        scores: &StripedScores<u8, <Avx2 as Backend>::LANES>,
+    ) -> Option<MatrixCoordinates> {
+        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
+        unsafe {
+            argmax_u8_avx2(scores)
         }
         #[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
         panic!("attempting to run AVX2 code on a non-x86 host")