diff --git a/lightmotif-bench/dna.rs b/lightmotif-bench/dna.rs
index e4c924a2fbbbc387c911ad8db6d3c62ac9fb3207..4239d520f0ac2ff31413b1d42c245e5aecba420c 100644
--- a/lightmotif-bench/dna.rs
+++ b/lightmotif-bench/dna.rs
@@ -48,9 +48,9 @@ fn bench_generic(bencher: &mut test::Bencher) {
     });
 }
 
-#[cfg(target_feature = "avx2")]
+#[cfg(target_feature = "ssse3")]
 #[bench]
-fn bench_sse2(bencher: &mut test::Bencher) {
+fn bench_ssse3(bencher: &mut test::Bencher) {
     let seq = &SEQUENCE[..10000];
     let encoded = EncodedSequence::<Dna>::from_str(seq).unwrap();
     let mut striped = encoded.to_striped();
@@ -69,7 +69,7 @@ fn bench_sse2(bencher: &mut test::Bencher) {
     bencher.bytes = seq.len() as u64;
     bencher.iter(|| {
         Pipeline::<_, __m128>::score_into(&striped, &pssm, &mut scores);
-        test::black_box(Pipeline::<_, __m128>::best_position(&scores));
+        test::black_box(Pipeline::<_, __m128>::best_position(&scores).unwrap());
     });
 }
 
diff --git a/lightmotif/src/pli.rs b/lightmotif/src/pli.rs
index 9d66130f2a4a6a7d84a63b4a2e401ccf9ac3acac..f7ce67b53991aff3ea56d7373d219136b81cfdc3 100644
--- a/lightmotif/src/pli.rs
+++ b/lightmotif/src/pli.rs
@@ -175,10 +175,88 @@ unsafe fn score_ssse3(
         }
         // record the score for the current position
         let row = &mut scores.data[i];
-        _mm_store_ps(row[0..].as_mut_ptr(), s1);
-        _mm_store_ps(row[4..].as_mut_ptr(), s2);
-        _mm_store_ps(row[8..].as_mut_ptr(), s3);
-        _mm_store_ps(row[12..].as_mut_ptr(), s4);
+        _mm_storeu_ps(row[0..].as_mut_ptr(), s1);
+        _mm_storeu_ps(row[4..].as_mut_ptr(), s2);
+        _mm_storeu_ps(row[8..].as_mut_ptr(), s3);
+        _mm_storeu_ps(row[12..].as_mut_ptr(), s4);
+    }
+}
+
+/// Find the index (`col * rows + row`) of the highest score in a
+/// 16-column striped score matrix, or `None` if `scores.length` is zero.
+///
+/// # Safety
+///
+/// This function is compiled with the `ssse3` target feature enabled;
+/// the caller must ensure the running CPU supports SSSE3.
+#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
+#[target_feature(enable = "ssse3")]
+unsafe fn best_position_ssse3(
+    scores: &StripedScores<{ std::mem::size_of::<__m128i>() }>,
+) -> Option<usize> {
+    if scores.length == 0 {
+        None
+    } else {
+        unsafe {
+            // the row index for the best score in each column
+            // (these are 32-bit integers but for use with the bitwise
+            // select below they get stored in 32-bit float vectors).
+            let mut p1 = _mm_setzero_ps();
+            let mut p2 = _mm_setzero_ps();
+            let mut p3 = _mm_setzero_ps();
+            let mut p4 = _mm_setzero_ps();
+            // store the best scores for each column
+            let mut s1 = _mm_loadu_ps(scores.data[0][0x00..].as_ptr());
+            let mut s2 = _mm_loadu_ps(scores.data[0][0x04..].as_ptr());
+            let mut s3 = _mm_loadu_ps(scores.data[0][0x08..].as_ptr());
+            let mut s4 = _mm_loadu_ps(scores.data[0][0x0c..].as_ptr());
+            // process all rows iteratively
+            for (i, row) in scores.data.iter().enumerate() {
+                // record the current row index
+                let index = _mm_castsi128_ps(_mm_set1_epi32(i as i32));
+                // load scores for the current row
+                let r1 = _mm_loadu_ps(row[0x00..].as_ptr());
+                let r2 = _mm_loadu_ps(row[0x04..].as_ptr());
+                let r3 = _mm_loadu_ps(row[0x08..].as_ptr());
+                let r4 = _mm_loadu_ps(row[0x0c..].as_ptr());
+                // compare scores to local maximums
+                let c1 = _mm_cmplt_ps(s1, r1);
+                let c2 = _mm_cmplt_ps(s2, r2);
+                let c3 = _mm_cmplt_ps(s3, r3);
+                let c4 = _mm_cmplt_ps(s4, r4);
+                // NOTE: code below could use `_mm_blendv_ps` instead,
+                // but this instruction is only available on SSE4.1
+                // while the rest of the code is actually using at
+                // most SSSE3 instructions.
+                // `_mm_andnot_ps(a, b)` computes `!a & b`, so the mask
+                // must be the first operand to keep the unchanged lanes.
+                // replace indices of new local maximums
+                p1 = _mm_or_ps(_mm_andnot_ps(c1, p1), _mm_and_ps(index, c1));
+                p2 = _mm_or_ps(_mm_andnot_ps(c2, p2), _mm_and_ps(index, c2));
+                p3 = _mm_or_ps(_mm_andnot_ps(c3, p3), _mm_and_ps(index, c3));
+                p4 = _mm_or_ps(_mm_andnot_ps(c4, p4), _mm_and_ps(index, c4));
+                // replace values of new local maximums
+                s1 = _mm_or_ps(_mm_andnot_ps(c1, s1), _mm_and_ps(r1, c1));
+                s2 = _mm_or_ps(_mm_andnot_ps(c2, s2), _mm_and_ps(r2, c2));
+                s3 = _mm_or_ps(_mm_andnot_ps(c3, s3), _mm_and_ps(r3, c3));
+                s4 = _mm_or_ps(_mm_andnot_ps(c4, s4), _mm_and_ps(r4, c4));
+            }
+            // find the global maximum across all columns
+            let mut x: [u32; 16] = [0; 16];
+            _mm_storeu_si128(x[0x00..].as_mut_ptr() as *mut _, _mm_castps_si128(p1));
+            _mm_storeu_si128(x[0x04..].as_mut_ptr() as *mut _, _mm_castps_si128(p2));
+            _mm_storeu_si128(x[0x08..].as_mut_ptr() as *mut _, _mm_castps_si128(p3));
+            _mm_storeu_si128(x[0x0c..].as_mut_ptr() as *mut _, _mm_castps_si128(p4));
+            let mut best_pos = 0;
+            let mut best_score = -f32::INFINITY;
+            for (col, &row) in x.iter().enumerate() {
+                if scores.data[row as usize][col] > best_score {
+                    best_score = scores.data[row as usize][col];
+                    best_pos = col * scores.data.rows() + row as usize;
+                }
+            }
+            Some(best_pos)
+        }
     }
 }
 
@@ -211,6 +289,10 @@ impl Score<Dna, { Dna::K }, __m128, { std::mem::size_of::<__m128i>() }> for Pipe
             score_ssse3(seq, pssm, scores);
         }
     }
+
+    fn best_position(scores: &StripedScores<{ std::mem::size_of::<__m128i>() }>) -> Option<usize> {
+        unsafe { best_position_ssse3(scores) }
+    }
 }
 
 // --- AVX2 --------------------------------------------------------------------
@@ -338,19 +420,19 @@ unsafe fn best_position_avx2(
             let mut p3 = _mm256_setzero_ps();
             let mut p4 = _mm256_setzero_ps();
             // store the best scores for each column
-            let mut s1 = _mm256_loadu_ps(scores.data[0][0x00..].as_ptr());
-            let mut s2 = _mm256_loadu_ps(scores.data[0][0x04..].as_ptr());
-            let mut s3 = _mm256_loadu_ps(scores.data[0][0x08..].as_ptr());
-            let mut s4 = _mm256_loadu_ps(scores.data[0][0x0c..].as_ptr());
+            let mut s1 = _mm256_load_ps(scores.data[0][0x00..].as_ptr());
+            let mut s2 = _mm256_load_ps(scores.data[0][0x08..].as_ptr());
+            let mut s3 = _mm256_load_ps(scores.data[0][0x10..].as_ptr());
+            let mut s4 = _mm256_load_ps(scores.data[0][0x18..].as_ptr());
             // process all rows iteratively
             for (i, row) in scores.data.iter().enumerate() {
                 // record the current row index
                 let index = _mm256_castsi256_ps(_mm256_set1_epi32(i as i32));
                 // load scores for the current row
-                let r1 = _mm256_loadu_ps(row[0x00..].as_ptr());
-                let r2 = _mm256_loadu_ps(row[0x04..].as_ptr());
-                let r3 = _mm256_loadu_ps(row[0x08..].as_ptr());
-                let r4 = _mm256_loadu_ps(row[0x0c..].as_ptr());
+                let r1 = _mm256_load_ps(row[0x00..].as_ptr());
+                let r2 = _mm256_load_ps(row[0x08..].as_ptr());
+                let r3 = _mm256_load_ps(row[0x10..].as_ptr());
+                let r4 = _mm256_load_ps(row[0x18..].as_ptr());
                 // compare scores to local maximums
                 let c1 = _mm256_cmp_ps(s1, r1, _CMP_LT_OS);
                 let c2 = _mm256_cmp_ps(s2, r2, _CMP_LT_OS);
@@ -370,9 +452,9 @@ unsafe fn best_position_avx2(
             // find the global maximum across all columns
             let mut x: [u32; 32] = [0; 32];
             _mm256_storeu_si256(x[0x00..].as_mut_ptr() as *mut _, _mm256_castps_si256(p1));
-            _mm256_storeu_si256(x[0x04..].as_mut_ptr() as *mut _, _mm256_castps_si256(p2));
-            _mm256_storeu_si256(x[0x08..].as_mut_ptr() as *mut _, _mm256_castps_si256(p3));
-            _mm256_storeu_si256(x[0x0c..].as_mut_ptr() as *mut _, _mm256_castps_si256(p4));
+            _mm256_storeu_si256(x[0x08..].as_mut_ptr() as *mut _, _mm256_castps_si256(p2));
+            _mm256_storeu_si256(x[0x10..].as_mut_ptr() as *mut _, _mm256_castps_si256(p3));
+            _mm256_storeu_si256(x[0x18..].as_mut_ptr() as *mut _, _mm256_castps_si256(p4));
             let mut best_pos = 0;
             let mut best_score = -f32::INFINITY;
             for (col, &row) in x.iter().enumerate() {
diff --git a/lightmotif/tests/dna.rs b/lightmotif/tests/dna.rs
index 3ee4c961b080da1cd7b6b1ddb92b1103ab2dba33..ae7e81eecb5626f22ad793f1181071ce721f6da3 100644
--- a/lightmotif/tests/dna.rs
+++ b/lightmotif/tests/dna.rs
@@ -76,7 +76,7 @@ fn test_best_position_generic() {
             .map(|x| EncodedSequence::from_str(x).unwrap()),
     )
     .unwrap();
-    let pbm = cm.to_freq([0.1, 0.1, 0.1, 0.1, 0.0]);
+    let pbm = cm.to_freq(0.1);
     let pwm = pbm.to_weight(None);
     let pssm = pwm.into();
 
@@ -124,6 +124,27 @@ fn test_score_ssse3() {
     }
 }
 
+#[cfg(target_feature = "ssse3")]
+#[test]
+fn test_best_position_ssse3() {
+    let encoded = EncodedSequence::<Dna>::from_str(SEQUENCE).unwrap();
+    let mut striped = encoded.to_striped();
+
+    let cm = CountMatrix::<Dna, { Dna::K }>::from_sequences(
+        PATTERNS
+            .iter()
+            .map(|x| EncodedSequence::from_str(x).unwrap()),
+    )
+    .unwrap();
+    let pbm = cm.to_freq(0.1);
+    let pwm = pbm.to_weight(None);
+    let pssm = pwm.into();
+
+    striped.configure(&pssm);
+    let result = Pipeline::<Dna, __m128>::score(&striped, &pssm);
+    assert_eq!(Pipeline::<Dna, __m128>::best_position(&result), Some(18));
+}
+
 #[cfg(target_feature = "avx2")]
 #[test]
 fn test_score_avx2() {
@@ -136,7 +157,7 @@ fn test_score_avx2() {
             .map(|x| EncodedSequence::from_str(x).unwrap()),
     )
     .unwrap();
-    let pbm = cm.to_freq([0.1, 0.1, 0.1, 0.1, 0.0]);
+    let pbm = cm.to_freq(0.1);
     let pwm = pbm.to_weight(None);
     let pssm = pwm.into();
 
@@ -144,10 +165,6 @@ fn test_score_avx2() {
     let result = Pipeline::<_, __m256>::score(&striped, &pssm);
     let scores = result.to_vec();
 
-    // for i in 0..result.data.rows() {
-    //     println!("{:?}", &result.data[i]);
-    // }
-
     assert_eq!(scores.len(), EXPECTED.len());
     for i in 0..EXPECTED.len() {
         assert!(