Skip to content
Snippets Groups Projects
Commit dae72f3f authored by Martin Larralde
Browse files

Add dedicated `Maximum::<f32, _>::max` implementation for AVX2

parent 0a587d21
No related branches found
No related tags found
No related merge requests found
......@@ -190,6 +190,16 @@ impl<A: Alphabet> Maximum<f32, <Dispatch as Backend>::LANES> for Pipeline<A, Dis
_ => <Generic as Maximum<f32, <Dispatch as Backend>::LANES>>::argmax(&Generic, scores),
}
}
/// Return the highest score held in `scores`, choosing the kernel from `self.backend`.
///
/// On x86 / x86_64 builds the AVX2 backend is served by `Avx2::max_f32`; every
/// other backend value goes through the `Generic` implementation of `Maximum`.
fn max(&self, scores: &StripedScores<f32, <Dispatch as Backend>::LANES>) -> Option<f32> {
    // The AVX2 path is only compiled on x86 targets, so the whole check is gated.
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if matches!(self.backend, Dispatch::Avx2) {
            return Avx2::max_f32(scores);
        }
    }
    // NOTE(review): a dedicated SSE2 arm was sketched here but left disabled
    // (it also referenced `argmax` rather than `max`); until one exists, all
    // remaining backends share the generic fallback below.
    <Generic as Maximum<f32, <Dispatch as Backend>::LANES>>::max(&Generic, scores)
}
}
impl<A: Alphabet> Maximum<u8, <Dispatch as Backend>::LANES> for Pipeline<A, Dispatch> {
......
......@@ -483,6 +483,10 @@ impl<A: Alphabet> Maximum<f32, <Avx2 as Backend>::LANES> for Pipeline<A, Avx2> {
) -> Option<MatrixCoordinates> {
Avx2::argmax_f32(scores)
}
/// Return the highest score held in `scores`.
///
/// Delegates directly to `Avx2::max_f32`; the result is whatever that
/// function returns for the given matrix.
fn max(&self, scores: &StripedScores<f32, <Avx2 as Backend>::LANES>) -> Option<f32> {
Avx2::max_f32(scores)
}
}
impl<A: Alphabet> Maximum<u8, <Avx2 as Backend>::LANES> for Pipeline<A, Avx2> {
......
......@@ -405,6 +405,47 @@ unsafe fn argmax_f32_avx2(
}
}
/// Find the largest value in a striped `f32` score matrix using AVX2 intrinsics.
///
/// Returns `None` when `scores.is_empty()`. Otherwise every row of
/// `scores.matrix()` is scanned, 32 values (four 8-lane vectors) per row, and
/// the overall maximum is returned as `Some`.
///
/// The running maxima are seeded from the first row rather than from zero, so
/// a matrix whose scores are all negative yields its true (negative) maximum
/// instead of `0.0`.
///
/// # Safety
///
/// The caller must ensure the running CPU supports AVX2. The aligned loads
/// additionally require each row to start on a 32-byte boundary and to hold at
/// least 32 `f32` values.
/// NOTE(review): the alignment and row width are assumed from the storage
/// behind `StripedScores`; confirm against its definition.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn max_f32_avx2(scores: &StripedScores<f32, <Avx2 as Backend>::LANES>) -> Option<f32> {
    if scores.is_empty() {
        return None;
    }
    let data = scores.matrix();
    unsafe {
        let mut dataptr = data[0].as_ptr();
        // Seed the per-column maxima with the first row: seeding with zeros
        // would clamp the result to 0.0 whenever every score is negative.
        let mut m1 = _mm256_load_ps(dataptr as *const _);
        let mut m2 = _mm256_load_ps(dataptr.add(0x08) as *const _);
        let mut m3 = _mm256_load_ps(dataptr.add(0x10) as *const _);
        let mut m4 = _mm256_load_ps(dataptr.add(0x18) as *const _);
        // Fold in the remaining rows, advancing *before* each load so the
        // pointer never moves past the last row.
        for _ in 1..data.rows() {
            dataptr = dataptr.add(data.stride());
            // load scores for the current row
            let r1 = _mm256_load_ps(dataptr as *const _);
            let r2 = _mm256_load_ps(dataptr.add(0x08) as *const _);
            let r3 = _mm256_load_ps(dataptr.add(0x10) as *const _);
            let r4 = _mm256_load_ps(dataptr.add(0x18) as *const _);
            // keep the highest score seen so far in each lane
            m1 = _mm256_max_ps(m1, r1);
            m2 = _mm256_max_ps(m2, r2);
            m3 = _mm256_max_ps(m3, r3);
            m4 = _mm256_max_ps(m4, r4);
        }
        // merge the four accumulators into a single 8-lane vector
        let m = _mm256_max_ps(_mm256_max_ps(m1, m2), _mm256_max_ps(m3, m4));
        // spill the lanes and reduce them to the global maximum
        let mut x: [f32; 8] = [0.0; 8];
        _mm256_storeu_ps(x.as_mut_ptr() as *mut _, m);
        x.into_iter().reduce(f32::max)
    }
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn argmax_u8_avx2(
......@@ -881,6 +922,16 @@ impl Avx2 {
panic!("attempting to run AVX2 code on a non-x86 host")
}
/// Return the highest score in `scores` via the AVX2 kernel `max_f32_avx2`.
///
/// # Panics
///
/// On targets other than x86 / x86_64 this function always panics with
/// "attempting to run AVX2 code on a non-x86 host".
#[allow(unused)]
pub fn max_f32(scores: &StripedScores<f32, <Avx2 as Backend>::LANES>) -> Option<f32> {
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
// NOTE(review): `max_f32_avx2` is an `unsafe fn` compiled with the `avx2`
// target feature and no CPU-feature check happens here, so soundness
// presumably relies on callers only selecting this backend after AVX2
// support has been detected; confirm at the call sites.
unsafe {
max_f32_avx2(scores)
}
// Non-x86 builds have no AVX2 kernel to call, so reaching this is a bug.
#[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
panic!("attempting to run AVX2 code on a non-x86 host")
}
#[allow(unused)]
pub fn argmax_u8(
scores: &StripedScores<u8, <Avx2 as Backend>::LANES>,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment