Commit 46297c96 authored by Martin Larralde

Use AVX2 masked stores instead of blend in `Avx2::threshold`

parent e2b8403b
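In short, the previous implementation blended the `u32::MAX` sentinel into the lanes whose score fell below the threshold and then stored all lanes, whereas the new one pre-fills `indices` with the sentinel and uses `_mm256_maskstore_epi32`, which writes only the lanes whose mask has its high bit set. The following standalone sketch illustrates that behaviour; it assumes an x86_64 target with AVX2, and the `demo` function, its inputs, and the 8-lane setup are made up for illustration rather than taken from the crate.

use std::arch::x86_64::*;

#[target_feature(enable = "avx2")]
unsafe fn demo(scores: &[f32; 8], threshold: f32, out: &mut [u32; 8]) {
    let t = _mm256_set1_ps(threshold);
    let s = _mm256_loadu_ps(scores.as_ptr());
    // lane i of the mask is all ones iff scores[i] >= threshold
    let m = _mm256_castps_si256(_mm256_cmp_ps(s, t, _CMP_GE_OS));
    // the candidate indices 0..8 for the 8 lanes
    let x = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);

    // Old approach (kept as a comment): blend the sentinel into rejected lanes,
    // then store every lane unconditionally.
    //   let max = _mm256_set1_epi32(u32::MAX as i32);
    //   let lt = _mm256_castps_si256(_mm256_cmp_ps(s, t, _CMP_LT_OS));
    //   let i = _mm256_blendv_epi8(x, max, lt);
    //   _mm256_storeu_si256(out.as_mut_ptr() as *mut __m256i, i);

    // New approach: store only the accepted lanes; rejected lanes keep whatever
    // the buffer already holds, so it must be pre-filled with the sentinel.
    _mm256_maskstore_epi32(out.as_mut_ptr() as *mut i32, m, x);
}

fn main() {
    if is_x86_feature_detected!("avx2") {
        let scores = [0.1f32, 2.0, 0.5, 3.0, 1.0, 0.0, 4.0, 0.2];
        let mut out = [u32::MAX; 8]; // pre-filled sentinel, as the commit does
        unsafe { demo(&scores, 1.0, &mut out) };
        println!("{:?}", out); // lanes with score >= 1.0 hold their index
    }
}

The sentinel then survives in the lanes the masked store skipped, and the later `retain` pass drops any entry whose value is out of range.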
@@ -195,11 +195,8 @@ unsafe fn threshold_avx2(
     } else {
         let data = scores.matrix();
         let rows = data.rows();
-        let mut indices = vec![0u32; data.columns() * rows];
+        let mut indices = vec![u32::MAX; data.columns() * rows];
         unsafe {
-            // NOTE(@althonos): Using `u32::MAX` as a sentinel instead of `0`
-            //                  because `0` may be a valid index.
-            let max = _mm256_set1_epi32(u32::MAX as i32);
             let t = _mm256_set1_ps(threshold);
             let ones = _mm256_set1_epi32(1);
             let mut dst = indices.as_mut_ptr() as *mut __m256i;
@@ -252,23 +249,18 @@ unsafe fn threshold_avx2(
                 let r3 = _mm256_load_ps(row[0x10..].as_ptr());
                 let r4 = _mm256_load_ps(row[0x18..].as_ptr());
                 // check whether scores are greater or equal to the threshold
-                let m1 = _mm256_castps_si256(_mm256_cmp_ps(r1, t, _CMP_LT_OS));
-                let m2 = _mm256_castps_si256(_mm256_cmp_ps(r2, t, _CMP_LT_OS));
-                let m3 = _mm256_castps_si256(_mm256_cmp_ps(r3, t, _CMP_LT_OS));
-                let m4 = _mm256_castps_si256(_mm256_cmp_ps(r4, t, _CMP_LT_OS));
-                // Mask indices that should be removed
-                let i1 = _mm256_blendv_epi8(x1, max, m1);
-                let i2 = _mm256_blendv_epi8(x2, max, m2);
-                let i3 = _mm256_blendv_epi8(x3, max, m3);
-                let i4 = _mm256_blendv_epi8(x4, max, m4);
-                // Store masked indices into the destination vector
-                _mm256_storeu_si256(dst, i1);
-                _mm256_storeu_si256(dst.add(1), i2);
-                _mm256_storeu_si256(dst.add(2), i3);
-                _mm256_storeu_si256(dst.add(3), i4);
-                // Advance result buffer to next row
+                let m1 = _mm256_castps_si256(_mm256_cmp_ps(r1, t, _CMP_GE_OS));
+                let m2 = _mm256_castps_si256(_mm256_cmp_ps(r2, t, _CMP_GE_OS));
+                let m3 = _mm256_castps_si256(_mm256_cmp_ps(r3, t, _CMP_GE_OS));
+                let m4 = _mm256_castps_si256(_mm256_cmp_ps(r4, t, _CMP_GE_OS));
+                // store masked indices into the destination vector
+                _mm256_maskstore_epi32(dst as *mut _, m1, x1);
+                _mm256_maskstore_epi32(dst.add(1) as *mut _, m2, x2);
+                _mm256_maskstore_epi32(dst.add(2) as *mut _, m3, x3);
+                _mm256_maskstore_epi32(dst.add(3) as *mut _, m4, x4);
+                // advance result buffer to next row
                 dst = dst.add(4);
-                // Advance sequence indices to next row
+                // advance sequence indices to next row
                 x1 = _mm256_add_epi32(x1, ones);
                 x2 = _mm256_add_epi32(x2, ones);
                 x3 = _mm256_add_epi32(x3, ones);
@@ -279,13 +271,6 @@ unsafe fn threshold_avx2(
         // NOTE: Benchmarks suggest that `indices.retain(...)` is faster than
         //       `indices.into_iter().filter(...)`.
-        // FIXME: The `Vec::retain` implementation may not be optimal for this,
-        //        since it takes extra care of the vector elements deallocation
-        //        because they may implement `Drop`. It may be faster to use
-        //        a double-pointer algorithm, swapping sentinels and concrete
-        //        values until the end of the vector is reached, and then
-        //        clipping the vector with `indices.set_len`.
-        // Remove all masked items and convert the indices to usize
         indices.retain(|&x| (x as usize) < scores.len());
         indices.into_iter().map(|i| i as usize).collect()
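The FIXME dropped in the last hunk alludes to a possible alternative to `Vec::retain`: a double-pointer pass that swaps concrete values over the sentinels and then clips the vector. Purely as an illustration of that remark (not code from the crate, and using `truncate` plus an explicit `!= u32::MAX` test where the crate compares against `scores.len()`), the idea might look like:

// Hypothetical compaction of a sentinel-filled index buffer with two cursors.
fn compact_sentinels(indices: &mut Vec<u32>) {
    const SENTINEL: u32 = u32::MAX;
    let mut write = 0;
    for read in 0..indices.len() {
        if indices[read] != SENTINEL {
            // move the concrete value into the next free slot
            indices.swap(write, read);
            write += 1;
        }
    }
    // safe stand-in for the `set_len` clipping mentioned in the comment;
    // `u32` is `Copy` with no `Drop`, so nothing needs to run for the tail
    indices.truncate(write);
}

fn main() {
    let mut indices = vec![0, u32::MAX, 2, u32::MAX, u32::MAX, 5];
    compact_sentinels(&mut indices);
    assert_eq!(indices, vec![0, 2, 5]);
}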