diff --git a/lightmotif-py/lightmotif/lib.rs b/lightmotif-py/lightmotif/lib.rs index 054bab463f649a0c587bf15717e864d2c9b25849..802d7f2dd24d5fa92a9bfc0effa1476ed626dc4c 100644 --- a/lightmotif-py/lightmotif/lib.rs +++ b/lightmotif-py/lightmotif/lib.rs @@ -14,7 +14,6 @@ use lightmotif::pli::Score; use pyo3::exceptions::PyBufferError; use pyo3::exceptions::PyIndexError; -use pyo3::exceptions::PyKeyError; use pyo3::exceptions::PyTypeError; use pyo3::exceptions::PyValueError; use pyo3::ffi::Py_ssize_t; @@ -390,7 +389,7 @@ impl From<lightmotif::pli::StripedScores<C>> for StripedScores { let shape = [cols as Py_ssize_t, rows as Py_ssize_t]; let strides = [ std::mem::size_of::<f32>() as Py_ssize_t, - (cols.next_power_of_two() * std::mem::size_of::<f32>()) as Py_ssize_t, + (scores.matrix().stride() * std::mem::size_of::<f32>()) as Py_ssize_t, ]; // mask the remaining positions that are outside the sequence length let length = scores.len(); diff --git a/lightmotif/src/dense.rs b/lightmotif/src/dense.rs index c1eef30b7ee01d0a36d1c3428129dab93d37a3dc..0a823fd13cce32b04a2ee79f95d019a8bdd61711 100644 --- a/lightmotif/src/dense.rs +++ b/lightmotif/src/dense.rs @@ -21,7 +21,7 @@ pub type DefaultAlignment = _DefaultAlignment; // --- DenseMatrix ------------------------------------------------------------- -/// An aligned dense matrix of with a constant number of columns. +/// A memory-aligned dense matrix with a constant number of columns. #[derive(Debug, Clone, PartialEq, Eq)] pub struct DenseMatrix<T: Default + Copy, C: Unsigned, A: Unsigned = DefaultAlignment> { data: Vec<T>, @@ -101,9 +101,21 @@ impl<T: Default + Copy, C: Unsigned, A: Unsigned> DenseMatrix<T, C, A> { C::USIZE } - /// The effective number of columns in the matrix, counting alignment. + /// The stride of the matrix, as a number of elements. + /// + /// This may be different from the number of columns to account for memory + /// alignment constraints. Multiply by `std::mem::size_of::<T>()` to obtain + /// the stride in bytes. + /// + /// # Example + /// ```rust + /// # use typenum::{U43, U32}; + /// # use lightmotif::dense::DenseMatrix; + /// let d = DenseMatrix::<u8, U43, U32>::new(0); + /// assert_eq!(d.stride(), 64); + /// ``` #[inline] - pub const fn columns_effective(&self) -> usize { + pub const fn stride(&self) -> usize { let x = std::mem::size_of::<T>(); let c = C::USIZE * x; let b = c + (A::USIZE - c % A::USIZE) * ((c % A::USIZE) > 0) as usize; @@ -119,7 +131,7 @@ impl<T: Default + Copy, C: Unsigned, A: Unsigned> DenseMatrix<T, C, A> { /// Change the number of rows of the matrix. pub fn resize(&mut self, rows: usize) { // Always over-allocate columns to avoid alignment issues. - let c: usize = self.columns_effective(); + let c: usize = self.stride(); // Cache previous dimensions let previous_rows = self.rows; @@ -167,7 +179,7 @@ impl<T: Default + Copy, C: Unsigned, A: Unsigned> Index<usize> for DenseMatrix<T type Output = [T]; #[inline] fn index(&self, index: usize) -> &Self::Output { - let c = self.columns_effective(); + let c = self.stride(); let row = self.offset + c * index; &self.data[row..row + C::USIZE] } @@ -176,7 +188,7 @@ impl<T: Default + Copy, C: Unsigned, A: Unsigned> Index<usize> for DenseMatrix<T impl<T: Default + Copy, C: Unsigned, A: Unsigned> IndexMut<usize> for DenseMatrix<T, C, A> { #[inline] fn index_mut(&mut self, index: usize) -> &mut Self::Output { - let c = self.columns_effective(); + let c = self.stride(); let row = self.offset + c * index; &mut self.data[row..row + C::USIZE] } @@ -233,7 +245,7 @@ where #[inline] fn get(&mut self, i: usize) -> &'a [T] { - let c = self.matrix.columns_effective(); + let c = self.matrix.stride(); let row = self.matrix.offset + c * i; unsafe { std::slice::from_raw_parts(self.data.as_ptr().add(row), C::USIZE) } } @@ -270,7 +282,7 @@ where #[inline] fn get(&mut self, i: usize) -> &'a mut [T] { - let c = self.matrix.columns_effective(); + let c = self.matrix.stride(); let row = self.matrix.offset + c * i; unsafe { std::slice::from_raw_parts_mut(self.data.as_ptr().add(row), C::USIZE) } } @@ -338,27 +350,27 @@ mod test { use super::*; #[test] - fn test_columns_effective() { + fn test_stride() { let d1 = DenseMatrix::<u8, U32, U32>::new(0); - assert_eq!(d1.columns_effective(), 32); + assert_eq!(d1.stride(), 32); let d2 = DenseMatrix::<u8, U16, U32>::new(0); - assert_eq!(d2.columns_effective(), 32); + assert_eq!(d2.stride(), 32); let d3 = DenseMatrix::<u32, U32, U32>::new(0); - assert_eq!(d3.columns_effective(), 32); + assert_eq!(d3.stride(), 32); let d4 = DenseMatrix::<u32, U8, U32>::new(0); - assert_eq!(d4.columns_effective(), 8); + assert_eq!(d4.stride(), 8); let d5 = DenseMatrix::<u32, U16, U32>::new(0); - assert_eq!(d5.columns_effective(), 16); + assert_eq!(d5.stride(), 16); let d6 = DenseMatrix::<u32, U3, U16>::new(0); - assert_eq!(d6.columns_effective(), 4); + assert_eq!(d6.stride(), 4); let d7 = DenseMatrix::<u8, U15, U8>::new(0); - assert_eq!(d7.columns_effective(), 16); + assert_eq!(d7.stride(), 16); } #[test] diff --git a/lightmotif/src/pli/platform/avx2.rs b/lightmotif/src/pli/platform/avx2.rs index b39802aa186830170d9773c7b2653a2493544fa7..f9e76d6312e2493c609c96e55351502363f6a9b5 100644 --- a/lightmotif/src/pli/platform/avx2.rs +++ b/lightmotif/src/pli/platform/avx2.rs @@ -98,8 +98,8 @@ unsafe fn score_avx2( s3 = _mm256_add_ps(s3, b3); s4 = _mm256_add_ps(s4, b4); // advance to next row in PSSM and sequence matrices - seqptr = seqptr.add(seq.data.columns_effective()); - pssmptr = pssmptr.add(pssm.weights().columns_effective()); + seqptr = seqptr.add(seq.data.stride()); + pssmptr = pssmptr.add(pssm.weights().stride()); } // permute lanes so that scores are in the right order let r1 = _mm256_permute2f128_ps(s1, s2, 0x20); @@ -166,7 +166,7 @@ unsafe fn best_position_avx2(scores: &StripedScores<<Avx2 as Backend>::LANES>) - s3 = _mm256_blendv_ps(s3, r3, c3); s4 = _mm256_blendv_ps(s4, r4, c4); // advance to next row - dataptr = dataptr.add(data.columns_effective()); + dataptr = dataptr.add(data.stride()); } // find the global maximum across all columns let mut x: [u32; 32] = [0; 32]; @@ -275,7 +275,7 @@ unsafe fn threshold_avx2( x3 = _mm256_add_epi32(x3, ones); x4 = _mm256_add_epi32(x4, ones); // Advance data pointer to next row - dataptr = dataptr.add(data.columns_effective()); + dataptr = dataptr.add(data.stride()); } } diff --git a/lightmotif/src/pli/platform/neon.rs b/lightmotif/src/pli/platform/neon.rs index 3d89d6493663c2621a61c85657e625e7d2e28a63..0e356a78577e71085eefd56bb572ab3137fa1a53 100644 --- a/lightmotif/src/pli/platform/neon.rs +++ b/lightmotif/src/pli/platform/neon.rs @@ -78,8 +78,8 @@ unsafe fn score_neon<A, C>( s.3 = vaddq_f32(s.3, vreinterpretq_f32_u32(vandq_u32(lut, p4))); } // advance to next row in sequence and PSSM matrices - dataptr = dataptr.add(seq.data.columns_effective()); - pssmptr = pssmptr.add(pssm.weights().columns_effective()); + dataptr = dataptr.add(seq.data.stride()); + pssmptr = pssmptr.add(pssm.weights().stride()); } // record the score for the current position let row = &mut data[i]; diff --git a/lightmotif/src/pli/platform/sse2.rs b/lightmotif/src/pli/platform/sse2.rs index 5ce58f62f2797082c796cea40626bed550afff5b..897a037abea513ef61b4e6b5cb6e25c7a96c113e 100644 --- a/lightmotif/src/pli/platform/sse2.rs +++ b/lightmotif/src/pli/platform/sse2.rs @@ -77,8 +77,8 @@ unsafe fn score_sse2<A, C>( s4 = _mm_add_ps(s4, _mm_and_ps(lut, p4)); } // advance to next row in sequence and PSSM matrices - dataptr = dataptr.add(seq.data.columns_effective()); - pssmptr = pssmptr.add(pssm.weights().columns_effective()); + dataptr = dataptr.add(seq.data.stride()); + pssmptr = pssmptr.add(pssm.weights().stride()); } // record the score for the current position let row = &mut data[i]; @@ -154,7 +154,7 @@ where s3 = _mm_or_ps(_mm_andnot_ps(c3, s3), _mm_and_ps(r3, c3)); s4 = _mm_or_ps(_mm_andnot_ps(c4, s4), _mm_and_ps(r4, c4)); // advance to next row - dataptr = dataptr.add(data.columns_effective()); + dataptr = dataptr.add(data.stride()); } // find the global maximum across all columns _mm_storeu_si128( @@ -274,7 +274,7 @@ where x3 = _mm_add_epi32(x3, ones); x4 = _mm_add_epi32(x4, ones); // Advance data pointer to next row - dataptr = dataptr.add(data.columns_effective()); + dataptr = dataptr.add(data.stride()); } } }