From dae72f3fc4d3a724fef0e90f1689ce4e3976de2e Mon Sep 17 00:00:00 2001 From: Martin Larralde <martin.larralde@embl.de> Date: Mon, 24 Jun 2024 15:17:16 +0200 Subject: [PATCH] Add dedicated `Maximum::<f32, _>::max` implementation for AVX2 --- lightmotif/src/pli/dispatch.rs | 10 ++++++ lightmotif/src/pli/mod.rs | 4 +++ lightmotif/src/pli/platform/avx2.rs | 51 +++++++++++++++++++++++++++++ 3 files changed, 65 insertions(+) diff --git a/lightmotif/src/pli/dispatch.rs b/lightmotif/src/pli/dispatch.rs index 1dc505c..fd4aed6 100644 --- a/lightmotif/src/pli/dispatch.rs +++ b/lightmotif/src/pli/dispatch.rs @@ -190,6 +190,16 @@ impl<A: Alphabet> Maximum<f32, <Dispatch as Backend>::LANES> for Pipeline<A, Dis _ => <Generic as Maximum<f32, <Dispatch as Backend>::LANES>>::argmax(&Generic, scores), } } + + fn max(&self, scores: &StripedScores<f32, <Dispatch as Backend>::LANES>) -> Option<f32> { + match self.backend { + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] + Dispatch::Avx2 => Avx2::max_f32(scores), + // #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] + // Dispatch::Sse2 => Sse2::argmax(scores), + _ => <Generic as Maximum<f32, <Dispatch as Backend>::LANES>>::max(&Generic, scores), + } + } } impl<A: Alphabet> Maximum<u8, <Dispatch as Backend>::LANES> for Pipeline<A, Dispatch> { diff --git a/lightmotif/src/pli/mod.rs b/lightmotif/src/pli/mod.rs index 83f9bc1..f594b91 100644 --- a/lightmotif/src/pli/mod.rs +++ b/lightmotif/src/pli/mod.rs @@ -483,6 +483,10 @@ impl<A: Alphabet> Maximum<f32, <Avx2 as Backend>::LANES> for Pipeline<A, Avx2> { ) -> Option<MatrixCoordinates> { Avx2::argmax_f32(scores) } + + fn max(&self, scores: &StripedScores<f32, <Avx2 as Backend>::LANES>) -> Option<f32> { + Avx2::max_f32(scores) + } } impl<A: Alphabet> Maximum<u8, <Avx2 as Backend>::LANES> for Pipeline<A, Avx2> { diff --git a/lightmotif/src/pli/platform/avx2.rs b/lightmotif/src/pli/platform/avx2.rs index 3df2136..40afaf2 100644 --- a/lightmotif/src/pli/platform/avx2.rs +++ b/lightmotif/src/pli/platform/avx2.rs @@ -405,6 +405,47 @@ unsafe fn argmax_f32_avx2( } } +#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] +#[target_feature(enable = "avx2")] +unsafe fn max_f32_avx2(scores: &StripedScores<f32, <Avx2 as Backend>::LANES>) -> Option<f32> { + if scores.is_empty() { + None + } else { + let data = scores.matrix(); + unsafe { + let mut dataptr = data[0].as_ptr(); + // store the best scores for each column + let mut m1 = _mm256_setzero_ps(); + let mut m2 = _mm256_setzero_ps(); + let mut m3 = _mm256_setzero_ps(); + let mut m4 = _mm256_setzero_ps(); + // process all rows iteratively + for i in 0..data.rows() { + // load scores for the current row + let r1 = _mm256_load_ps(dataptr as *const _); + let r2 = _mm256_load_ps(dataptr.add(0x08) as *const _); + let r3 = _mm256_load_ps(dataptr.add(0x10) as *const _); + let r4 = _mm256_load_ps(dataptr.add(0x18) as *const _); + // find highest score + m1 = _mm256_max_ps(m1, r1); + m2 = _mm256_max_ps(m2, r2); + m3 = _mm256_max_ps(m3, r3); + m4 = _mm256_max_ps(m4, r4); + // advance to next row + dataptr = dataptr.add(data.stride()); + } + + // + let m = _mm256_max_ps(_mm256_max_ps(m1, m2), _mm256_max_ps(m3, m4)); + + // find the global maximum across all columns + let mut x: [f32; 8] = [0.0; 8]; + _mm256_storeu_ps(x.as_mut_ptr() as *mut _, m); + x.into_iter().reduce(f32::max) + } + } +} + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] #[target_feature(enable = "avx2")] unsafe fn argmax_u8_avx2( @@ -881,6 +922,16 @@ impl Avx2 { panic!("attempting to run AVX2 code on a non-x86 host") } + #[allow(unused)] + pub fn max_f32(scores: &StripedScores<f32, <Avx2 as Backend>::LANES>) -> Option<f32> { + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] + unsafe { + max_f32_avx2(scores) + } + #[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))] + panic!("attempting to run AVX2 code on a non-x86 host") + } + #[allow(unused)] pub fn argmax_u8( scores: &StripedScores<u8, <Avx2 as Backend>::LANES>, -- GitLab