diff --git a/lightmotif/src/pli/platform/avx2.rs b/lightmotif/src/pli/platform/avx2.rs
index 40afaf27aa16ef25e9039b0bd06ab3d53d96ce9a..36c31f8e65f166c622bc25162992f63c41f7c1cf 100644
--- a/lightmotif/src/pli/platform/avx2.rs
+++ b/lightmotif/src/pli/platform/avx2.rs
@@ -461,6 +461,7 @@ unsafe fn argmax_u8_avx2(
     } else {
         let data = scores.matrix();
         unsafe {
+            let ones = _mm256_set1_epi16(1);
             let mut dataptr = data[0].as_ptr();
             // the row index for the best score in each column
             // (these are 32-bit integers but for use with `_mm256_blendv_ps`
@@ -468,8 +469,8 @@ unsafe fn argmax_u8_avx2(
             let mut p1 = _mm256_setzero_si256();
             let mut p2 = _mm256_setzero_si256();
             // store the best scores for each column
-            let mut s1 = _mm256_setzero_si256();
-            let mut s2 = _mm256_setzero_si256();
+            let mut s1 = _mm256_set1_epi16(-1);
+            let mut s2 = _mm256_set1_epi16(-1);
             // process all rows iteratively
             for i in 0..data.rows() {
                 // record the current row index
@@ -483,37 +484,28 @@ unsafe fn argmax_u8_avx2(
                 let r1 = _mm256_unpacklo_epi8(r, _mm256_setzero_si256());
                 let r2 = _mm256_unpackhi_epi8(r, _mm256_setzero_si256());
                 // compare scores to local maximums
-                let c1 = _mm256_or_si256(_mm256_cmpeq_epi16(r1, s1), _mm256_cmpgt_epi16(r1, s1));
-                let c2 = _mm256_or_si256(_mm256_cmpeq_epi16(r2, s2), _mm256_cmpgt_epi16(r2, s2));
+                let c1 = _mm256_cmpgt_epi16(r1, s1);
+                let c2 = _mm256_cmpgt_epi16(r2, s2);
                 // replace indices of new local maximums
                 p1 = _mm256_blendv_epi8(p1, index, c1);
                 p2 = _mm256_blendv_epi8(p2, index, c2);
-                // replace values of new local maximums
-                s1 = _mm256_blendv_epi8(s1, r1, c1);
-                s2 = _mm256_blendv_epi8(s2, r2, c2);
+                // replace values of new local maximums (minus one, so that
+                // we can do a `_mm256_cmpgt_epi16` comparison instead of a
+                // `_mm256_cmpge_epi16` which doesn't exist on AVX2)
+                s1 = _mm256_blendv_epi8(s1, _mm256_sub_epi16(r1, ones), c1);
+                s2 = _mm256_blendv_epi8(s2, _mm256_sub_epi16(r2, ones), c2);
                 // advance to next row
                 dataptr = dataptr.add(data.stride());
             }
-            // find the global maximum across all columns
+            // record the column-local maxima
             let mut x: [u16; 32] = [0; 32];
             _mm256_storeu_si256(x.as_mut_ptr() as *mut _, p1);
             _mm256_storeu_si256(x[16..].as_mut_ptr() as *mut _, p2);
-
-            // println!("{:?}", x);
-
-            let mut best_pos = MatrixCoordinates::default();
-            let mut best_score = data[best_pos];
-
-            for (col, &row) in x.iter().enumerate() {
-                let pos = MatrixCoordinates::new(row as usize, col);
-                let score = data[pos];
-                if (score >= best_score) {
-                    best_score = score;
-                    best_pos = pos;
-                }
-            }
-
-            Some(best_pos)
+            // find the global maximum across all columns
+            x.into_iter()
+                .enumerate()
+                .map(|(col, row)| MatrixCoordinates::new(row as usize, col))
+                .max_by_key(|&pos| &data[pos])
         }
     }
 }