From 579d9527f520907bd2aed7fbf6bdb72d3ebc2ceb Mon Sep 17 00:00:00 2001 From: Martin Larralde <martin.larralde@embl.de> Date: Mon, 24 Jun 2024 16:02:11 +0200 Subject: [PATCH] Update `Avx2::argmax_u8` to use a single comparison --- lightmotif/src/pli/platform/avx2.rs | 40 ++++++++++++----------------- 1 file changed, 16 insertions(+), 24 deletions(-) diff --git a/lightmotif/src/pli/platform/avx2.rs b/lightmotif/src/pli/platform/avx2.rs index 40afaf2..36c31f8 100644 --- a/lightmotif/src/pli/platform/avx2.rs +++ b/lightmotif/src/pli/platform/avx2.rs @@ -461,6 +461,7 @@ unsafe fn argmax_u8_avx2( } else { let data = scores.matrix(); unsafe { + let ones = _mm256_set1_epi16(1); let mut dataptr = data[0].as_ptr(); // the row index for the best score in each column // (these are 32-bit integers but for use with `_mm256_blendv_ps` @@ -468,8 +469,8 @@ unsafe fn argmax_u8_avx2( let mut p1 = _mm256_setzero_si256(); let mut p2 = _mm256_setzero_si256(); // store the best scores for each column - let mut s1 = _mm256_setzero_si256(); - let mut s2 = _mm256_setzero_si256(); + let mut s1 = _mm256_set1_epi16(-1); + let mut s2 = _mm256_set1_epi16(-1); // process all rows iteratively for i in 0..data.rows() { // record the current row index @@ -483,37 +484,28 @@ unsafe fn argmax_u8_avx2( let r1 = _mm256_unpacklo_epi8(r, _mm256_setzero_si256()); let r2 = _mm256_unpackhi_epi8(r, _mm256_setzero_si256()); // compare scores to local maximums - let c1 = _mm256_or_si256(_mm256_cmpeq_epi16(r1, s1), _mm256_cmpgt_epi16(r1, s1)); - let c2 = _mm256_or_si256(_mm256_cmpeq_epi16(r2, s2), _mm256_cmpgt_epi16(r2, s2)); + let c1 = _mm256_cmpgt_epi16(r1, s1); + let c2 = _mm256_cmpgt_epi16(r2, s2); // replace indices of new local maximums p1 = _mm256_blendv_epi8(p1, index, c1); p2 = _mm256_blendv_epi8(p2, index, c2); - // replace values of new local maximums - s1 = _mm256_blendv_epi8(s1, r1, c1); - s2 = _mm256_blendv_epi8(s2, r2, c2); + // replace values of new local maximums (minus one, so that + // we can do a `_mm256_cmpgt_epi16` comparison instead of a + // `_mm256_cmpge_epi16` which doesn't exist on AVX2) + s1 = _mm256_blendv_epi8(s1, _mm256_sub_epi16(r1, ones), c1); + s2 = _mm256_blendv_epi8(s2, _mm256_sub_epi16(r2, ones), c2); // advance to next row dataptr = dataptr.add(data.stride()); } - // find the global maximum across all columns + // record the column-local maxima let mut x: [u16; 32] = [0; 32]; _mm256_storeu_si256(x.as_mut_ptr() as *mut _, p1); _mm256_storeu_si256(x[16..].as_mut_ptr() as *mut _, p2); - - // println!("{:?}", x); - - let mut best_pos = MatrixCoordinates::default(); - let mut best_score = data[best_pos]; - - for (col, &row) in x.iter().enumerate() { - let pos = MatrixCoordinates::new(row as usize, col); - let score = data[pos]; - if (score >= best_score) { - best_score = score; - best_pos = pos; - } - } - - Some(best_pos) + // find the global maximum across all columns + x.into_iter() + .enumerate() + .map(|(col, row)| MatrixCoordinates::new(row as usize, col)) + .max_by_key(|&pos| &data[pos]) } } } -- GitLab