From 4132b956e2655bdd56726381e9bbcdd8a6648149 Mon Sep 17 00:00:00 2001
From: Martin Larralde <martin.larralde@embl.de>
Date: Mon, 24 Jun 2024 14:08:52 +0200
Subject: [PATCH] Add dedicated `Maximum::<u8, _>::max` implementation for AVX2

---
 lightmotif/src/pli/dispatch.rs      | 10 ++++++
 lightmotif/src/pli/mod.rs           |  4 +++
 lightmotif/src/pli/platform/avx2.rs | 54 +++++++++++++++++++++++++++++
 3 files changed, 68 insertions(+)

diff --git a/lightmotif/src/pli/dispatch.rs b/lightmotif/src/pli/dispatch.rs
index 7fa8add..1dc505c 100644
--- a/lightmotif/src/pli/dispatch.rs
+++ b/lightmotif/src/pli/dispatch.rs
@@ -205,6 +205,16 @@ impl<A: Alphabet> Maximum<u8, <Dispatch as Backend>::LANES> for Pipeline<A, Disp
             _ => <Generic as Maximum<u8, <Dispatch as Backend>::LANES>>::argmax(&Generic, scores),
         }
     }
+
+    fn max(&self, scores: &StripedScores<u8, <Dispatch as Backend>::LANES>) -> Option<u8> {
+        match self.backend {
+            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
+            Dispatch::Avx2 => Avx2::max_u8(scores),
+            // TODO: add an SSE2 arm once `Sse2` provides a `u8` max kernel; the
+            // generic fallback below currently serves every non-AVX2 backend.
+            _ => <Generic as Maximum<u8, <Dispatch as Backend>::LANES>>::max(&Generic, scores),
+        }
+    }
 }
 
 impl<A: Alphabet> Threshold<f32, <Dispatch as Backend>::LANES> for Pipeline<A, Dispatch> {}
diff --git a/lightmotif/src/pli/mod.rs b/lightmotif/src/pli/mod.rs
index 80a053b..83f9bc1 100644
--- a/lightmotif/src/pli/mod.rs
+++ b/lightmotif/src/pli/mod.rs
@@ -492,6 +492,10 @@ impl<A: Alphabet> Maximum<u8, <Avx2 as Backend>::LANES> for Pipeline<A, Avx2> {
     ) -> Option<MatrixCoordinates> {
         Avx2::argmax_u8(scores)
     }
+
+    fn max(&self, scores: &StripedScores<u8, <Avx2 as Backend>::LANES>) -> Option<u8> {
+        Avx2::max_u8(scores)
+    }
 }
 
 impl<A: Alphabet> Threshold<f32, <Avx2 as Backend>::LANES> for Pipeline<A, Avx2> {}
diff --git a/lightmotif/src/pli/platform/avx2.rs b/lightmotif/src/pli/platform/avx2.rs
index 02bc8d0..3df2136 100644
--- a/lightmotif/src/pli/platform/avx2.rs
+++ b/lightmotif/src/pli/platform/avx2.rs
@@ -477,6 +477,43 @@ unsafe fn argmax_u8_avx2(
     }
 }
 
+/// Compute the highest value stored in a striped `u8` score matrix.
+///
+/// Returns `None` when `scores` is empty.
+///
+/// # Safety
+///
+/// The CPU must support AVX2. Rows are read with the aligned load
+/// `_mm256_load_si256`, so each row must start on a 32-byte boundary
+/// (NOTE(review): confirm the score matrix layout guarantees this).
+#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
+#[target_feature(enable = "avx2")]
+unsafe fn max_u8_avx2(scores: &StripedScores<u8, <Avx2 as Backend>::LANES>) -> Option<u8> {
+    if scores.is_empty() {
+        None
+    } else {
+        let data = scores.matrix();
+        unsafe {
+            let mut dataptr = data[0].as_ptr();
+            // running per-lane maximum; zero is the identity for unsigned max
+            let mut m = _mm256_setzero_si256();
+            // fold every row into the running maximum
+            for _ in 0..data.rows() {
+                // load scores for the current row
+                let r = _mm256_load_si256(dataptr as *const _);
+                // keep the lane-wise highest score
+                m = _mm256_max_epu8(m, r);
+                // advance to next row
+                dataptr = dataptr.add(data.stride());
+            }
+            // spill the 32 lanes and reduce them to the global maximum
+            let mut x: [u8; 32] = [0; 32];
+            _mm256_storeu_si256(x.as_mut_ptr() as *mut _, m);
+            x.into_iter().max()
+        }
+    }
+}
+
 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
 #[target_feature(enable = "avx2")]
 unsafe fn stripe_avx2<A>(
@@ -865,6 +902,23 @@ impl Avx2 {
         panic!("attempting to run AVX2 code on a non-x86 host")
     }
 
+    /// Return the highest score in `scores` using the AVX2 kernel.
+    ///
+    /// # Panics
+    ///
+    /// Panics when compiled for a target other than x86 / x86-64.
+    #[allow(unused)]
+    pub fn max_u8(scores: &StripedScores<u8, <Avx2 as Backend>::LANES>) -> Option<u8> {
+        // SAFETY: NOTE(review): assumes the caller already verified AVX2
+        // support at runtime; nothing in this function checks it.
+        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
+        unsafe {
+            max_u8_avx2(scores)
+        }
+        #[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
+        panic!("attempting to run AVX2 code on a non-x86 host")
+    }
+
     #[allow(unused)]
     pub fn stripe_into<A, S>(seq: S, matrix: &mut StripedSequence<A, <Avx2 as Backend>::LANES>)
     where
-- 
GitLab