From 7f70de3bbfc9476b923f669f5a4c05523576eb0b Mon Sep 17 00:00:00 2001
From: Martin Larralde <martin.larralde@embl.de>
Date: Wed, 10 May 2023 16:30:10 +0200
Subject: [PATCH] Add method to extract the best position from a
 `StripedScores` matrix with AVX2

---
 lightmotif-bench/dna.rs |   8 +--
 lightmotif/src/pli.rs   | 119 +++++++++++++++++++++++++++++++++-------
 lightmotif/tests/dna.rs |  44 ++++++++++++++++
 3 files changed, 149 insertions(+), 22 deletions(-)

diff --git a/lightmotif-bench/dna.rs b/lightmotif-bench/dna.rs
index 0aed75a..e4c924a 100644
--- a/lightmotif-bench/dna.rs
+++ b/lightmotif-bench/dna.rs
@@ -42,7 +42,9 @@ fn bench_generic(bencher: &mut test::Bencher) {
     bencher.bytes = seq.len() as u64;
     bencher.iter(|| {
         Pipeline::<_, f32>::score_into(&striped, &pssm, &mut scores);
-        test::black_box(scores.argmax());
+        test::black_box(
+            <Pipeline<_, f32> as Score<Dna, { Dna::K }, f32, 32>>::best_position(&scores),
+        );
     });
 }
 
@@ -67,7 +69,7 @@ fn bench_sse2(bencher: &mut test::Bencher) {
     bencher.bytes = seq.len() as u64;
     bencher.iter(|| {
         Pipeline::<_, __m128>::score_into(&striped, &pssm, &mut scores);
-        test::black_box(scores.argmax());
+        test::black_box(Pipeline::<_, __m128>::best_position(&scores));
     });
 }
 
@@ -92,7 +94,7 @@ fn bench_avx2(bencher: &mut test::Bencher) {
     bencher.bytes = seq.len() as u64;
     bencher.iter(|| {
         Pipeline::<_, __m256>::score_into(&striped, &pssm, &mut scores);
-        test::black_box(scores.argmax());
+        test::black_box(Pipeline::<_, __m256>::best_position(&scores).unwrap());
     });
 }
 
diff --git a/lightmotif/src/pli.rs b/lightmotif/src/pli.rs
index 67a5519..9d66130 100644
--- a/lightmotif/src/pli.rs
+++ b/lightmotif/src/pli.rs
@@ -30,11 +30,13 @@ mod seal {
 
 /// Generic trait for computing sequence scores with a PSSM.
 pub trait Score<A: Alphabet, const K: usize, V: Vector, const C: usize> {
+    /// Compute the PSSM scores into the given buffer.
     fn score_into<S, M>(seq: S, pssm: M, scores: &mut StripedScores<C>)
     where
         S: AsRef<StripedSequence<A, C>>,
         M: AsRef<ScoringMatrix<A, K>>;
 
+    /// Compute the PSSM scores for every sequence position.
     fn score<S, M>(seq: S, pssm: M) -> StripedScores<C>
     where
         S: AsRef<StripedSequence<A, C>>,
@@ -44,6 +46,26 @@ pub trait Score<A: Alphabet, const K: usize, V: Vector, const C: usize> {
         Self::score_into(seq, pssm, &mut scores);
         scores
     }
+
+    /// Find the sequence position with the highest score.
+    fn best_position(scores: &StripedScores<C>) -> Option<usize> {
+        if scores.length == 0 {
+            return None;
+        }
+
+        let mut best_pos = 0;
+        let mut best_score = scores.data[0][0];
+        for i in 0..scores.length {
+            let col = i / scores.data.rows();
+            let row = i % scores.data.rows();
+            if scores.data[row][col] > best_score {
+                best_score = scores.data[row][col];
+                best_pos = i;
+            }
+        }
+
+        Some(best_pos)
+    }
 }
 
 // --- Pipeline ----------------------------------------------------------------
@@ -299,6 +321,80 @@ unsafe fn score_avx2(
     }
 }
 
+/// Find the position with the highest score using AVX2 instructions.
+///
+/// Each of the 32 columns of the striped matrix is tracked in one `f32`
+/// lane (four 8-lane vectors); the per-column maxima are then reduced
+/// with a scalar loop. Returns `None` when `scores` is empty.
+///
+/// # Safety
+///
+/// The caller must ensure that the CPU supports the AVX2 instruction set.
+#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
+#[target_feature(enable = "avx2")]
+unsafe fn best_position_avx2(
+    scores: &StripedScores<{ std::mem::size_of::<__m256i>() }>,
+) -> Option<usize> {
+    if scores.length == 0 {
+        None
+    } else {
+        unsafe {
+            // the row index for the best score in each column
+            // (these are 32-bit integers but for use with `_mm256_blendv_ps`
+            // they get stored in 32-bit float vectors).
+            let mut p1 = _mm256_setzero_ps();
+            let mut p2 = _mm256_setzero_ps();
+            let mut p3 = _mm256_setzero_ps();
+            let mut p4 = _mm256_setzero_ps();
+            // store the best scores for each column (8 `f32` lanes per vector)
+            let mut s1 = _mm256_loadu_ps(scores.data[0][0x00..].as_ptr());
+            let mut s2 = _mm256_loadu_ps(scores.data[0][0x08..].as_ptr());
+            let mut s3 = _mm256_loadu_ps(scores.data[0][0x10..].as_ptr());
+            let mut s4 = _mm256_loadu_ps(scores.data[0][0x18..].as_ptr());
+            // process all rows iteratively
+            for (i, row) in scores.data.iter().enumerate() {
+                // record the current row index
+                let index = _mm256_castsi256_ps(_mm256_set1_epi32(i as i32));
+                // load scores for the current row
+                let r1 = _mm256_loadu_ps(row[0x00..].as_ptr());
+                let r2 = _mm256_loadu_ps(row[0x08..].as_ptr());
+                let r3 = _mm256_loadu_ps(row[0x10..].as_ptr());
+                let r4 = _mm256_loadu_ps(row[0x18..].as_ptr());
+                // compare scores to local maximums
+                let c1 = _mm256_cmp_ps(s1, r1, _CMP_LT_OS);
+                let c2 = _mm256_cmp_ps(s2, r2, _CMP_LT_OS);
+                let c3 = _mm256_cmp_ps(s3, r3, _CMP_LT_OS);
+                let c4 = _mm256_cmp_ps(s4, r4, _CMP_LT_OS);
+                // replace indices of new local maximums
+                p1 = _mm256_blendv_ps(p1, index, c1);
+                p2 = _mm256_blendv_ps(p2, index, c2);
+                p3 = _mm256_blendv_ps(p3, index, c3);
+                p4 = _mm256_blendv_ps(p4, index, c4);
+                // replace values of new local maximums
+                s1 = _mm256_blendv_ps(s1, r1, c1);
+                s2 = _mm256_blendv_ps(s2, r2, c2);
+                s3 = _mm256_blendv_ps(s3, r3, c3);
+                s4 = _mm256_blendv_ps(s4, r4, c4);
+            }
+            // find the global maximum across all columns
+            let mut x: [u32; 32] = [0; 32];
+            _mm256_storeu_si256(x[0x00..].as_mut_ptr() as *mut _, _mm256_castps_si256(p1));
+            _mm256_storeu_si256(x[0x08..].as_mut_ptr() as *mut _, _mm256_castps_si256(p2));
+            _mm256_storeu_si256(x[0x10..].as_mut_ptr() as *mut _, _mm256_castps_si256(p3));
+            _mm256_storeu_si256(x[0x18..].as_mut_ptr() as *mut _, _mm256_castps_si256(p4));
+            let mut best_pos = 0;
+            let mut best_score = -f32::INFINITY;
+            for (col, &row) in x.iter().enumerate() {
+                if scores.data[row as usize][col] > best_score {
+                    best_score = scores.data[row as usize][col];
+                    best_pos = col * scores.data.rows() + row as usize;
+                }
+            }
+            Some(best_pos)
+        }
+    }
+}
+
 /// Intel 256-bit vector implementation, for 32 elements column width.
 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
 impl Score<Dna, { Dna::K }, __m256, { std::mem::size_of::<__m256i>() }> for Pipeline<Dna, __m256> {
@@ -329,6 +425,10 @@ impl Score<Dna, { Dna::K }, __m256, { std::mem::size_of::<__m256i>() }> for Pipe
             score_avx2(seq, pssm, scores);
         }
     }
+
+    fn best_position(scores: &StripedScores<{ std::mem::size_of::<__m256i>() }>) -> Option<usize> {
+        unsafe { best_position_avx2(scores) }
+    }
 }
 
 // --- StripedScores -----------------------------------------------------------
@@ -403,25 +503,6 @@ impl<const C: usize> StripedScores<C> {
         }
         vec
     }
-
-    /// Get the index of the highest scoring position, if any.
-    pub fn argmax(&self) -> Option<usize> {
-        if self.len() > 0 {
-            let mut best_pos = 0;
-            let mut best_score = self.data[0][0];
-            for i in 0..self.length {
-                let col = i / self.data.rows();
-                let row = i % self.data.rows();
-                if self.data[row][col] > best_score {
-                    best_pos = i;
-                    best_score = self.data[row][col];
-                }
-            }
-            Some(best_pos)
-        } else {
-            None
-        }
-    }
 }
 
 impl<const C: usize> AsRef<DenseMatrix<f32, C>> for StripedScores<C> {
diff --git a/lightmotif/tests/dna.rs b/lightmotif/tests/dna.rs
index d4b0455..3ee4c96 100644
--- a/lightmotif/tests/dna.rs
+++ b/lightmotif/tests/dna.rs
@@ -65,6 +65,29 @@ fn test_score_generic() {
     }
 }
 
+#[test]
+fn test_best_position_generic() {
+    let encoded = EncodedSequence::<Dna>::from_str(SEQUENCE).unwrap();
+    let mut striped = encoded.to_striped::<2>();
+
+    let cm = CountMatrix::<Dna, { Dna::K }>::from_sequences(
+        PATTERNS
+            .iter()
+            .map(|x| EncodedSequence::from_str(x).unwrap()),
+    )
+    .unwrap();
+    let pbm = cm.to_freq([0.1, 0.1, 0.1, 0.1, 0.0]);
+    let pwm = pbm.to_weight(None);
+    let pssm = pwm.into();
+
+    striped.configure(&pssm);
+    let result = Pipeline::<Dna, f32>::score(&striped, &pssm);
+    assert_eq!(
+        <Pipeline::<Dna, f32> as Score<Dna, { Dna::K }, _, 2>>::best_position(&result),
+        Some(18)
+    );
+}
+
 #[cfg(target_feature = "ssse3")]
 #[test]
 fn test_score_ssse3() {
@@ -136,3 +159,24 @@ fn test_score_avx2() {
         );
     }
 }
+
+#[cfg(target_feature = "avx2")]
+#[test]
+fn test_best_position_avx2() {
+    let encoded = EncodedSequence::<Dna>::from_str(SEQUENCE).unwrap();
+    let mut striped = encoded.to_striped();
+
+    let cm = CountMatrix::<Dna, { Dna::K }>::from_sequences(
+        PATTERNS
+            .iter()
+            .map(|x| EncodedSequence::from_str(x).unwrap()),
+    )
+    .unwrap();
+    let pbm = cm.to_freq([0.1, 0.1, 0.1, 0.1, 0.0]);
+    let pwm = pbm.to_weight(None);
+    let pssm = pwm.into();
+
+    striped.configure(&pssm);
+    let result = Pipeline::<Dna, __m256>::score(&striped, &pssm);
+    assert_eq!(Pipeline::<Dna, __m256>::best_position(&result), Some(18));
+}
-- 
GitLab