Skip to content
Snippets Groups Projects
Commit 9f3f67e4 authored by Martin Larralde
Browse files

Add SSSE3 implementation using less optimized masked additions

parent 0aaa8df6
No related branches found
No related tags found
No related merge requests found
......@@ -71,18 +71,20 @@ motif from [PRODORIC](https://www.prodoric.de/), and the
- Score every position of the genome with the motif weight matrix:
```console
running 2 tests
test bench_avx ... bench: 13,294,196 ns/iter (+/- 73,022) = 349 MB/s
test bench_generic ... bench: 316,647,932 ns/iter (+/- 1,420,798) = 14 MB/s
running 3 tests
test bench_avx2 ... bench: 13,053,752 ns/iter (+/- 45,411) = 355 MB/s
test bench_ssse3 ... bench: 37,203,277 ns/iter (+/- 2,416,572) = 124 MB/s
test bench_generic ... bench: 314,682,807 ns/iter (+/- 1,072,174) = 14 MB/s
```
- Find the highest-scoring position for a motif in a sequence
- Find the highest-scoring position for a motif in a 10kb sequence
(compared to the PSSM algorithm implemented in
[`bio::pattern_matching::pssm`](https://docs.rs/bio/1.1.0/bio/pattern_matching/pssm/index.html)):
```console
test bench_avx ... bench: 47,069 ns/iter (+/- 10) = 212 MB/s
test bench_bio ... bench: 1,437,308 ns/iter (+/- 5,419) = 6 MB/s
test bench_generic ... bench: 740,348 ns/iter (+/- 2,277) = 13 MB/s
test bench_avx2 ... bench: 46,390 ns/iter (+/- 115) = 215 MB/s
test bench_ssse3 ... bench: 97,691 ns/iter (+/- 2,720) = 102 MB/s
test bench_generic ... bench: 740,305 ns/iter (+/- 2,527) = 13 MB/s
test bench_bio ... bench: 1,575,504 ns/iter (+/- 2,799) = 6 MB/s
```
......
......@@ -8,6 +8,9 @@ extern crate test;
#[cfg(target_feature = "avx2")]
use std::arch::x86_64::__m256;
#[cfg(target_feature = "sse2")]
use std::arch::x86_64::__m128;
use lightmotif::Alphabet;
use lightmotif::Background;
use lightmotif::CountMatrix;
......@@ -45,10 +48,37 @@ fn bench_generic(bencher: &mut test::Bencher) {
// Benchmark: scores `SEQUENCE[..10000]` with the `__m128` pipeline and takes
// the argmax of the resulting scores on every iteration.
// NOTE(review): this span is diff output; the two `fn` lines and the two
// `to_striped` lines below are removed/added pairs of the same line.
// NOTE(review): the function is named `sse2` and gated on `avx2`, while the
// `__m128` pipeline implementation visible in this commit is gated on
// `ssse3` -- confirm which gate is intended.
#[cfg(target_feature = "avx2")]
#[bench]
fn bench_avx(bencher: &mut test::Bencher) {
fn bench_sse2(bencher: &mut test::Bencher) {
// only a 10000-element prefix of the sequence is scored here
let seq = &SEQUENCE[..10000];
let encoded = EncodedSequence::<DnaAlphabet>::from_text(seq).unwrap();
let mut striped = encoded.to_striped::<32>();
let mut striped = encoded.to_striped();
let bg = Background::<DnaAlphabet, { DnaAlphabet::K }>::uniform();
// count matrix built from two example sites, then converted to a weight matrix
let cm = CountMatrix::<DnaAlphabet, { DnaAlphabet::K }>::from_sequences(&[
EncodedSequence::from_text("GTTGACCTTATCAAC").unwrap(),
EncodedSequence::from_text("GTTGATCCAGTCAAC").unwrap(),
])
.unwrap();
let pbm = cm.to_probability(0.1);
let pwm = pbm.to_weight(bg);
// `configure` is called before scoring; presumably it sets up the wrapping
// rows that `score_into` checks for -- confirm against its definition
striped.configure(&pwm);
let pli = Pipeline::<_, __m128>::new();
// the score buffer is allocated once and reused by every iteration
let mut scores = StripedScores::new_for(&striped, &pwm);
// lets the harness report throughput relative to the sliced sequence length
bencher.bytes = seq.len() as u64;
bencher.iter(|| {
pli.score_into(&striped, &pwm, &mut scores);
// black_box keeps the argmax from being optimized away
test::black_box(scores.argmax());
});
}
#[cfg(target_feature = "avx2")]
#[bench]
fn bench_avx2(bencher: &mut test::Bencher) {
let seq = &SEQUENCE[..10000];
let encoded = EncodedSequence::<DnaAlphabet>::from_text(seq).unwrap();
let mut striped = encoded.to_striped();
let bg = Background::<DnaAlphabet, { DnaAlphabet::K }>::uniform();
let cm = CountMatrix::<DnaAlphabet, { DnaAlphabet::K }>::from_sequences(&[
......
......@@ -3,6 +3,9 @@
extern crate lightmotif;
extern crate test;
#[cfg(target_feature = "ssse3")]
use std::arch::x86_64::__m128;
#[cfg(target_feature = "avx2")]
use std::arch::x86_64::__m256;
......@@ -37,14 +40,37 @@ fn bench_generic(bencher: &mut test::Bencher) {
bencher.iter(|| test::black_box(pli.score_into(&striped, &pwm, &mut scores)));
}
/// Benchmark the `__m128` pipeline scoring the whole of `SEQUENCE` against a
/// weight matrix built from two example sites.
#[cfg(target_feature = "ssse3")]
#[bench]
fn bench_ssse3(bencher: &mut test::Bencher) {
    // encode and stripe the target sequence
    let target = EncodedSequence::<DnaAlphabet>::from_text(SEQUENCE).unwrap();
    let mut striped = target.to_striped();

    // build the weight matrix: counts -> probabilities -> weights
    let background = Background::uniform();
    let sites = [
        EncodedSequence::from_text("GTTGACCTTATCAAC").unwrap(),
        EncodedSequence::from_text("GTTGATCCAGTCAAC").unwrap(),
    ];
    let counts = CountMatrix::from_sequences(&sites).unwrap();
    let pwm = counts.to_probability(0.1).to_weight(background);

    // prepare the striped sequence for this motif and allocate the output once
    striped.configure(&pwm);
    let pipeline = Pipeline::<_, __m128>::new();
    let mut scores = StripedScores::new_for(&striped, &pwm);

    // throughput is reported relative to the full sequence length
    bencher.bytes = SEQUENCE.len() as u64;
    bencher.iter(|| {
        test::black_box(pipeline.score_into(&striped, &pwm, &mut scores))
    });
}
#[cfg(target_feature = "avx2")]
#[bench]
fn bench_avx(bencher: &mut test::Bencher) {
fn bench_avx2(bencher: &mut test::Bencher) {
let encoded = EncodedSequence::<DnaAlphabet>::from_text(SEQUENCE).unwrap();
let mut striped = encoded.to_striped::<32>();
let mut striped = encoded.to_striped();
let bg = Background::<DnaAlphabet, { DnaAlphabet::K }>::uniform();
let cm = CountMatrix::<DnaAlphabet, { DnaAlphabet::K }>::from_sequences(&[
let bg = Background::uniform();
let cm = CountMatrix::from_sequences(&[
EncodedSequence::from_text("GTTGACCTTATCAAC").unwrap(),
EncodedSequence::from_text("GTTGATCCAGTCAAC").unwrap(),
])
......
#[cfg(target_feature = "avx2")]
#[cfg(target_feature = "ssse3")]
use std::arch::x86_64::*;
use self::seal::Vector;
......@@ -16,6 +16,9 @@ mod seal {
#[cfg(target_feature = "avx2")]
impl Vector for std::arch::x86_64::__m256 {}
#[cfg(target_feature = "ssse3")]
impl Vector for std::arch::x86_64::__m128 {}
}
pub struct Pipeline<A: Alphabet, V: Vector> {
......@@ -76,6 +79,112 @@ impl Pipeline<DnaAlphabet, f32> {
}
}
/// SSSE3 scoring pipeline for DNA sequences.
///
/// Works on striped sequences with `size_of::<__m128i>()` (16) columns: each
/// sequence row is loaded as one `__m128i`, widened to four vectors of four
/// 32-bit lanes, and scored into 16 `f32` values per output row.
#[cfg(target_feature = "ssse3")]
impl Pipeline<DnaAlphabet, __m128> {
    /// Score every position of `seq` with `pwm`, writing the result into `scores`.
    ///
    /// `scores.length` is set to the number of positions where the whole motif
    /// fits in the sequence (`seq.length - pwm.len() + 1`), or to `0` when the
    /// sequence is shorter than the motif.
    ///
    /// # Panics
    /// Panics if `seq` has fewer than `pwm.len() - 1` wrapping rows, or if
    /// `scores` has fewer rows than `seq` has non-wrapping rows.
    pub fn score_into<S, M>(
        &self,
        seq: S,
        pwm: M,
        scores: &mut StripedScores<__m128, { std::mem::size_of::<__m128i>() }>,
    ) where
        S: AsRef<StripedSequence<DnaAlphabet, { std::mem::size_of::<__m128i>() }>>,
        M: AsRef<WeightMatrix<DnaAlphabet, { DnaAlphabet::K }>>,
    {
        let seq = seq.as_ref();
        let pwm = pwm.as_ref();

        if seq.wrap < pwm.len() - 1 {
            panic!("not enough wrapping rows for motif of length {}", pwm.len());
        }
        // number of sequence rows to score (wrapping rows are only read from)
        let rows = seq.data.rows() - seq.wrap;
        if scores.data.rows() < rows {
            panic!(
                "not enough rows for scores: {} < {}",
                scores.data.rows(),
                rows
            );
        }

        // A sequence shorter than the motif has no valid position; saturate
        // instead of underflowing to a huge length in release builds.
        scores.length = (seq.length + 1).saturating_sub(pwm.len());
        let result = &mut scores.data;

        // SAFETY: the intrinsics used are available because this impl is gated
        // on `target_feature = "ssse3"`. Every row index is bounds-checked by
        // the indexing operators before a pointer is taken.
        // NOTE(review): `_mm_load_si128` and `_mm_store_ps` require 16-byte
        // aligned pointers; this assumes the rows of `seq.data` and
        // `scores.data` are allocated with that alignment -- confirm.
        unsafe {
            // Shuffle masks widening 4 consecutive bytes to 4 x u32 lanes: the
            // low byte of each lane selects a source byte, the 0xFF bytes
            // (high bit set) make `_mm_shuffle_epi8` write zero.
            let m1 = _mm_set_epi32(
                0xFFFFFF03u32 as i32,
                0xFFFFFF02u32 as i32,
                0xFFFFFF01u32 as i32,
                0xFFFFFF00u32 as i32,
            );
            let m2 = _mm_set_epi32(
                0xFFFFFF07u32 as i32,
                0xFFFFFF06u32 as i32,
                0xFFFFFF05u32 as i32,
                0xFFFFFF04u32 as i32,
            );
            let m3 = _mm_set_epi32(
                0xFFFFFF0Bu32 as i32,
                0xFFFFFF0Au32 as i32,
                0xFFFFFF09u32 as i32,
                0xFFFFFF08u32 as i32,
            );
            let m4 = _mm_set_epi32(
                0xFFFFFF0Fu32 as i32,
                0xFFFFFF0Eu32 as i32,
                0xFFFFFF0Du32 as i32,
                0xFFFFFF0Cu32 as i32,
            );

            // process every (non-wrapping) row of the sequence data
            for i in 0..rows {
                // reset sums for the current row: 4 vectors x 4 lanes = 16 columns
                let mut s1 = _mm_setzero_ps();
                let mut s2 = _mm_setzero_ps();
                let mut s3 = _mm_setzero_ps();
                let mut s4 = _mm_setzero_ps();

                // advance position in the position weight matrix
                for j in 0..pwm.len() {
                    // load the sequence row and widen each byte to a 32-bit lane
                    let x = _mm_load_si128(seq.data[i + j].as_ptr() as *const __m128i);
                    let x1 = _mm_shuffle_epi8(x, m1);
                    let x2 = _mm_shuffle_epi8(x, m2);
                    let x3 = _mm_shuffle_epi8(x, m3);
                    let x4 = _mm_shuffle_epi8(x, m4);

                    // weights for the current motif position
                    let weights = pwm.data[j].as_ptr();

                    // Masked addition: for each symbol `k`, add its weight only
                    // to the lanes whose sequence symbol equals `k`.
                    for k in 0..DnaAlphabet::K {
                        let sym = _mm_set1_epi32(k as i32);
                        let lut = _mm_set1_ps(*weights.add(k as usize));
                        let p1 = _mm_castsi128_ps(_mm_cmpeq_epi32(x1, sym));
                        let p2 = _mm_castsi128_ps(_mm_cmpeq_epi32(x2, sym));
                        let p3 = _mm_castsi128_ps(_mm_cmpeq_epi32(x3, sym));
                        let p4 = _mm_castsi128_ps(_mm_cmpeq_epi32(x4, sym));
                        s1 = _mm_add_ps(s1, _mm_and_ps(lut, p1));
                        s2 = _mm_add_ps(s2, _mm_and_ps(lut, p2));
                        s3 = _mm_add_ps(s3, _mm_and_ps(lut, p3));
                        s4 = _mm_add_ps(s4, _mm_and_ps(lut, p4));
                    }
                }

                // record the 16 scores for the current row
                let out = &mut result[i];
                _mm_store_ps(out[0..].as_mut_ptr(), s1);
                _mm_store_ps(out[4..].as_mut_ptr(), s2);
                _mm_store_ps(out[8..].as_mut_ptr(), s3);
                _mm_store_ps(out[12..].as_mut_ptr(), s4);
            }
        }
    }

    /// Score every position of `seq` with `pwm` into a newly created buffer.
    ///
    /// Creates the buffer with `StripedScores::new_for` and fills it with
    /// [`score_into`](Self::score_into); panics under the same conditions.
    pub fn score<S, M>(
        &self,
        seq: S,
        pwm: M,
    ) -> StripedScores<__m128, { std::mem::size_of::<__m128i>() }>
    where
        S: AsRef<StripedSequence<DnaAlphabet, { std::mem::size_of::<__m128i>() }>>,
        M: AsRef<WeightMatrix<DnaAlphabet, { DnaAlphabet::K }>>,
    {
        let mut scores = StripedScores::new_for(&seq, &pwm);
        self.score_into(seq, pwm, &mut scores);
        scores
    }
}
#[cfg(target_feature = "avx2")]
impl Pipeline<DnaAlphabet, __m256> {
pub fn score_into<S, M>(
......@@ -287,6 +396,43 @@ impl<const C: usize> StripedScores<f32, C> {
}
}
#[cfg(target_feature = "ssse3")]
impl<const C: usize> StripedScores<__m128, C> {
    /// Flatten the striped scores into a vector ordered by sequence position.
    ///
    /// Position `i` is read from row `i % rows` and column `i / rows`.
    pub fn to_vec(&self) -> Vec<f32> {
        let mut flat = Vec::with_capacity(self.length);
        flat.extend((0..self.length).map(|pos| {
            let col = pos / self.data.rows();
            let row = pos % self.data.rows();
            self.data[row][col]
        }));
        flat
    }

    /// Return the position of the highest score.
    ///
    /// The first position wins on ties, and `0` is returned when no position
    /// scores strictly higher than the first cell of the buffer.
    ///
    /// ## Panic
    /// Panics if the data buffer is empty.
    pub fn argmax(&self) -> usize {
        let rows = self.data.rows();
        // (position, score) of the best candidate seen so far
        let mut best = (0, self.data[0][0]);
        // walk down each column in turn, row by row
        let (mut row, mut col) = (0, 0);
        for pos in 0..self.length {
            let score = self.data[row][col];
            if score > best.1 {
                best = (pos, score);
            }
            row += 1;
            if row == rows {
                row = 0;
                col += 1;
            }
        }
        best.0
    }
}
#[cfg(target_feature = "avx2")]
impl<const C: usize> StripedScores<__m256, C> {
/// Convert the striped scores to a vector of scores.
......
extern crate lightmotif;
#[cfg(target_feature = "ssse3")]
use std::arch::x86_64::__m128;
#[cfg(target_feature = "avx2")]
use std::arch::x86_64::__m256;
......@@ -62,10 +64,45 @@ fn test_score_generic() {
// assert_eq!(result.data[0][0], -23.07094); // -23.07094
}
/// Check that the `__m128` pipeline reproduces the `EXPECTED` scores for
/// `SEQUENCE` to within `1e-5` at every position.
#[cfg(target_feature = "ssse3")]
#[test]
fn test_score_ssse3() {
    let encoded = EncodedSequence::<DnaAlphabet>::from_text(SEQUENCE).unwrap();
    let mut striped = encoded.to_striped();

    // weight matrix built from the reference patterns
    let sites = PATTERNS
        .iter()
        .map(|x| EncodedSequence::from_text(x).unwrap());
    let cm = CountMatrix::from_sequences(sites).unwrap();
    let pbm = cm.to_probability([0.1, 0.1, 0.1, 0.1, 0.0]);
    let pwm = pbm.to_weight([0.25, 0.25, 0.25, 0.25, 0.0]);

    striped.configure(&pwm);
    let scores = Pipeline::<_, __m128>::new()
        .score(&striped, &pwm)
        .to_vec();

    // one score per expected value, each within the tolerance
    assert_eq!(scores.len(), EXPECTED.len());
    for (i, &score) in scores.iter().enumerate() {
        let expected = EXPECTED[i];
        let delta = (score - expected).abs();
        assert!(
            delta < 1e-5,
            "{} != {} at position {}",
            score,
            expected,
            i
        );
    }
}
#[cfg(target_feature = "avx2")]
#[test]
fn test_score_avx2() {
// let seq = "ATGTCCCAACAACGATACCCCGAGCCCATCGCCGTCATCGGCTCGGCATGCAGATTCCCAGGCGCATCCAGCTCGCCCTCCAAGCTGTGGAGTCTGCTCCAGGAACCTCGCGACGTCCTCAAGAAGTTCGACCCAGACCGCCTCAACCTGAAACGATTCCATCATACCAACGGTGACACTCACGGTGCGACCGACGTCAACAACAAATCATATCTCCTCGAAGAAAACACCCGACTCTTCGATGCCTCGTTCTTCGGAATCAGCCCCCTGGAGGCGGCCGGTATGGACCCCCAGCAGCGTCTGTTGCTGGAAACCGTCTACGAGTCGTTTGAGGCGGCTGGCGTGACCCTCGATCAGCTCAAGGGTTCTTTGACCTCGGTTCATGTTGGCGTCATGACCAACGACTACTCCTTTATCCAGCTCCGTGACCCAGAAACGCTGTCGAAGTACAACGCGACTGGCACGGCCAACAGCATCATGTCGAACCGTATTTCATATGTCTTTGACTTGAAAGGTCCATCAGAGACCATCGACACGGCGTGCTCCAGCTCGCTGGTCGCCCTGCACCACGCTGCTCAGGGCCTGCTCAGCGGCGACTGCGAGACTGCCGTCGTCGCCGGCGTCAACCTCATCTTCGACCCCTCTCCATACATCACAGAGTCCAAGCTACACATGCTGTCACCCGACTCCCAGTCTCGCATGTGGGACAAGTCTGCAAATGGCTACGCCCGCGGCGAGGGCGCTGCCGCGCTGCTCCTGAAGCCCCTCAGCCGCGCCCTGAGGGACGGCGATCACATCGAGGGCATTGTCCGAGGCACAGGAGTCAACTCGGACGGCCAGAGCTCCGGCATCACCATGCCTTTTGCCCCTGCGCAGTCGGCGCTCATTCGCCAAACTTATCTCCGTGCTGGCCTCGACCCGATCAAGGACCGGCCTCAGTACTTCGAGTGCCACGGCACCGGAACTCCAGCTGGTGACCCCGTGGAAGCGCGAGCCATCAGCGAGTCGTTGTTGGACGGTGAAAATGTCCCAACAACGATACCCCGAGCCCATCGCCGTCATCGGCTCGGCATGCAGATTCCCAGGCGCATCCAGCTCGCCCTCCAAGCTGTGGAGTCTGCTCCAGGAACCTCGCGACGTCCTCAAGAAGTTCGACCCAGACCGCCTCAACCTGAAACGATTCCATCATACCAACGGTGACACTCACGGTGCGACCGACGTCAACAACAAATCATATCTCCTCGAAGAAAACACCCGACTCTTCGATGCCTCGTTCTTCGGAATCAGCCCCCTGGAGGCGGCCGGTATGGACCCCCAGCAGCGTCTGTTGCTGGAAACCGTCTACGAGTCGTTTGAGGCGGCTGGCGTGACCCTCGATCAGCTCAAGGGTTCTTTGACCTCGGTTCATGTTGGCGTCATGACCAACGACTACTCCTTTATCCAGCTCCGTGACCCAGAAACGCTGTCGAAGTACAACGCGACTGGCACGGCCAACAGCATCATGTCGAACCGTATTTCATATGTCTTTGACTTGAAAGGTCCATCAGAGACCATCGACACGGCGTGCTCCAGCTCGCTGGTCGCCCTGCACCACGCTGCTCAGGGCCTGCTCAGCGGCGACTGCGAGACTGCCGTCGTCGCCGGCGTCAACCTCATCTTCGACCCCTCTCCATACATCACAGAGTCCAAGCTACACATGCTGTCACCCGACTCCCAGTCTCGCATGTGGGACAAGTCTGCAAATGGCTACGCCCGCGGCGAGGGCGCTGCCGCGCTGCTCCTGAAGCCCCTCAGCCGCGCCCTGAGGGACGGCGATCACATCGAGGGCATTGTCCGAGGCACAGGAGTCAACTCGGACGGCCAGAGCTCCGGCATCACCATGCCTTTTGCCCCTGCGCAGTCGGCGCTCATTCGCCAAACTTATCTCCGTGCTGGCCTCGACCCGATCAAGGACCGGCCTCAGTACTTCGAGTGCCACGGCACCGG
AACTCCAGCTGGTGACCCCGTGGAAGCGCGAGCCATCAGCGAGTCGTTGTTGGACGGTGAAA";
let encoded = EncodedSequence::<DnaAlphabet>::from_text(SEQUENCE).unwrap();
let mut striped = encoded.to_striped::<{ std::mem::size_of::<__m256>() }>();
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment