Skip to content
Snippets Groups Projects
Commit c6811bbd authored by Martin Larralde's avatar Martin Larralde
Browse files

Fix recovery of aligned scores in AVX2 implementation of the pipeline

parent 41113b90
No related branches found
No related tags found
No related merge requests found
...@@ -16,10 +16,10 @@ pub use abc::DnaSymbol; ...@@ -16,10 +16,10 @@ pub use abc::DnaSymbol;
pub use abc::Symbol; pub use abc::Symbol;
pub use dense::DenseMatrix; pub use dense::DenseMatrix;
pub use pli::Pipeline; pub use pli::Pipeline;
pub use pli::StripedScores;
pub use pwm::Background; pub use pwm::Background;
pub use pwm::CountMatrix; pub use pwm::CountMatrix;
pub use pwm::ProbabilityMatrix; pub use pwm::ProbabilityMatrix;
pub use pwm::StripedScores;
pub use pwm::WeightMatrix; pub use pwm::WeightMatrix;
pub use seq::EncodedSequence; pub use seq::EncodedSequence;
pub use seq::StripedSequence; pub use seq::StripedSequence;
...@@ -6,7 +6,6 @@ use super::abc::Alphabet; ...@@ -6,7 +6,6 @@ use super::abc::Alphabet;
use super::abc::DnaAlphabet; use super::abc::DnaAlphabet;
use super::abc::Symbol; use super::abc::Symbol;
use super::dense::DenseMatrix; use super::dense::DenseMatrix;
use super::pwm::StripedScores;
use super::pwm::WeightMatrix; use super::pwm::WeightMatrix;
use super::seq::EncodedSequence; use super::seq::EncodedSequence;
use super::seq::StripedSequence; use super::seq::StripedSequence;
...@@ -39,11 +38,11 @@ impl Pipeline<DnaAlphabet, f32> { ...@@ -39,11 +38,11 @@ impl Pipeline<DnaAlphabet, f32> {
&self, &self,
seq: &StripedSequence<DnaAlphabet, C>, seq: &StripedSequence<DnaAlphabet, C>,
pwm: &WeightMatrix<DnaAlphabet, { DnaAlphabet::K }>, pwm: &WeightMatrix<DnaAlphabet, { DnaAlphabet::K }>,
) -> StripedScores<C> { ) -> StripedScores<f32, C> {
let mut result = DenseMatrix::<f32, C>::new(seq.data.rows()); let mut result = DenseMatrix::<f32, C>::new(seq.data.rows());
for i in 0..seq.length - pwm.data.rows() + 1 { for i in 0..seq.length - pwm.len() + 1 {
let mut score = 0.0; let mut score = 0.0;
for j in 0..pwm.data.rows() { for j in 0..pwm.len() {
let offset = i + j; let offset = i + j;
let col = offset / seq.data.rows(); let col = offset / seq.data.rows();
let row = offset % seq.data.rows(); let row = offset % seq.data.rows();
...@@ -54,8 +53,9 @@ impl Pipeline<DnaAlphabet, f32> { ...@@ -54,8 +53,9 @@ impl Pipeline<DnaAlphabet, f32> {
result[row][col] = score; result[row][col] = score;
} }
StripedScores { StripedScores {
length: seq.length - pwm.data.rows() + 1, length: seq.length - pwm.len() + 1,
data: result, data: result,
marker: std::marker::PhantomData,
} }
} }
} }
...@@ -66,16 +66,18 @@ impl Pipeline<DnaAlphabet, __m256> { ...@@ -66,16 +66,18 @@ impl Pipeline<DnaAlphabet, __m256> {
&self, &self,
seq: &StripedSequence<DnaAlphabet, { std::mem::size_of::<__m256i>() }>, seq: &StripedSequence<DnaAlphabet, { std::mem::size_of::<__m256i>() }>,
pwm: &WeightMatrix<DnaAlphabet, { DnaAlphabet::K }>, pwm: &WeightMatrix<DnaAlphabet, { DnaAlphabet::K }>,
) -> StripedScores<{ std::mem::size_of::<__m256i>() }> { ) -> StripedScores<__m256, { std::mem::size_of::<__m256i>() }> {
const S: i32 = std::mem::size_of::<f32>() as i32; const S: i32 = std::mem::size_of::<f32>() as i32;
const C: usize = std::mem::size_of::<__m256i>(); const C: usize = std::mem::size_of::<__m256i>();
const K: usize = DnaAlphabet::K; const K: usize = DnaAlphabet::K;
let mut result = DenseMatrix::new((seq.length + C) / C); if (seq.wrap < pwm.len() - 1) {
panic!("not enough wrapping rows for motif of length {}", pwm.len());
}
let mut result = DenseMatrix::new(seq.data.rows() - seq.wrap);
unsafe { unsafe {
// get raw pointers to data // get raw pointers to data
let sdata = seq.data[0].as_ptr();
let mdata = pwm.data[0].as_ptr();
// mask vectors for broadcasting: // mask vectors for broadcasting:
let m1: __m256i = _mm256_set_epi32( let m1: __m256i = _mm256_set_epi32(
0xFFFFFF03u32 as i32, 0xFFFFFF03u32 as i32,
...@@ -118,13 +120,13 @@ impl Pipeline<DnaAlphabet, __m256> { ...@@ -118,13 +120,13 @@ impl Pipeline<DnaAlphabet, __m256> {
0xFFFFFF0Cu32 as i32, 0xFFFFFF0Cu32 as i32,
); );
// loop over every row of the sequence data // loop over every row of the sequence data
for i in 0..seq.data.rows() - pwm.data.rows() + 1 { for i in 0..seq.data.rows() - seq.wrap {
let mut s1 = _mm256_setzero_ps(); let mut s1 = _mm256_setzero_ps();
let mut s2 = _mm256_setzero_ps(); let mut s2 = _mm256_setzero_ps();
let mut s3 = _mm256_setzero_ps(); let mut s3 = _mm256_setzero_ps();
let mut s4 = _mm256_setzero_ps(); let mut s4 = _mm256_setzero_ps();
for j in 0..pwm.data.rows() { for j in 0..pwm.len() {
let x = _mm256_load_si256(seq.data[i + j].as_ptr() as *const __m256i); let x = _mm256_load_si256(seq.data[i+j].as_ptr() as *const __m256i);
let row = pwm.data[j].as_ptr(); let row = pwm.data[j].as_ptr();
// compute probabilities using an external lookup table // compute probabilities using an external lookup table
let p1 = _mm256_i32gather_ps(row, _mm256_shuffle_epi8(x, m1), S); let p1 = _mm256_i32gather_ps(row, _mm256_shuffle_epi8(x, m1), S);
...@@ -137,19 +139,64 @@ impl Pipeline<DnaAlphabet, __m256> { ...@@ -137,19 +139,64 @@ impl Pipeline<DnaAlphabet, __m256> {
s3 = _mm256_add_ps(s3, p3); s3 = _mm256_add_ps(s3, p3);
s4 = _mm256_add_ps(s4, p4); s4 = _mm256_add_ps(s4, p4);
} }
let row = &mut result[i]; let row = &mut result[i];
let rowptr = row.as_mut_ptr(); _mm256_storeu_ps(row[0..].as_mut_ptr(), s1);
_mm256_store_ps(rowptr, s1); _mm256_storeu_ps(row[8..].as_mut_ptr(), s2);
_mm256_store_ps(rowptr.add(0x08), s2); _mm256_storeu_ps(row[16..].as_mut_ptr(), s3);
_mm256_store_ps(rowptr.add(0x10), s3); _mm256_storeu_ps(row[24..].as_mut_ptr(), s4);
_mm256_store_ps(rowptr.add(0x18), s4);
} }
} }
StripedScores { StripedScores {
length: seq.length - pwm.data.rows() + 1, length: seq.length - pwm.len() + 1,
data: result, data: result,
marker: std::marker::PhantomData,
} }
} }
} }
/// Striped matrix of scores produced by a pipeline, tagged with the vector
/// type `V` of the pipeline implementation that computed it.
///
/// `V` is never stored: it only selects which `to_vec` implementation is
/// available (one exists for `f32`, and one for `__m256` when AVX2 is enabled),
/// since the two implementations lay scores out differently in `data`.
#[derive(Clone, Debug)]
pub struct StripedScores<V: Vector, const C: usize = 32> {
/// Number of scores that `to_vec` recovers from `data`.
pub length: usize,
/// Raw score storage, indexed as `data[row][column]`.
pub data: DenseMatrix<f32, C>,
/// Zero-sized marker tying these scores to the vector type `V`.
marker: std::marker::PhantomData<V>,
}
impl<const C: usize> StripedScores<f32, C> {
    /// Flatten the striped scores into a vector of `self.length` values.
    ///
    /// The score at index `i` of the returned vector is read from row
    /// `i % rows` and column `i / rows` of `self.data`, where `rows` is
    /// the number of rows of the matrix.
    pub fn to_vec(&self) -> Vec<f32> {
        let rows = self.data.rows();
        (0..self.length)
            .map(|pos| {
                // Column first, then row, as consecutive positions walk down a column.
                let (col, row) = (pos / rows, pos % rows);
                self.data[row][col]
            })
            .collect()
    }
}
#[cfg(target_feature = "avx2")]
impl<const C: usize> StripedScores<__m256, C> {
    /// Flatten scores computed by the AVX2 pipeline into a vector of
    /// `self.length` values.
    ///
    /// Scores are read column by column, walking down every row of
    /// `self.data` before moving to the next column. The physical column
    /// visited for the n-th logical column is `COLS[n]`, which undoes the
    /// column permutation left in the matrix by the AVX2 implementation.
    ///
    /// # Panics
    ///
    /// Panics if `self.length` requires more than the 32 columns listed in
    /// the permutation table, or if an index falls outside `self.data`.
    pub fn to_vec(&self) -> Vec<f32> {
        // NOTE(@althonos): Because in AVX2 the __m256 vector is actually
        //                  two independent __m128, the shuffling interleaves
        //                  the results: each group of 8 stored floats holds
        //                  4 scores from the low lane and 4 from the high lane.
        #[rustfmt::skip]
        const COLS: &[usize] = &[
             0,  1,  2,  3,  8,  9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27,
             4,  5,  6,  7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31,
        ];
        let mut col = 0;
        let mut row = 0;
        let mut vec = Vec::with_capacity(self.length);
        // The iteration index itself is not needed: the (row, col) cursor
        // below tracks the current position.
        for _ in 0..self.length {
            vec.push(self.data[row][COLS[col]]);
            row += 1;
            // Reached the bottom of the matrix: continue at the top of the
            // next logical column.
            if row >= self.data.rows() {
                row = 0;
                col += 1;
            }
        }
        vec
    }
}
\ No newline at end of file
...@@ -166,20 +166,9 @@ pub struct WeightMatrix<A: Alphabet, const K: usize> { ...@@ -166,20 +166,9 @@ pub struct WeightMatrix<A: Alphabet, const K: usize> {
pub name: String, pub name: String,
} }
#[derive(Clone, Debug)] impl<A: Alphabet, const K: usize> WeightMatrix<A, K> {
pub struct StripedScores<const C: usize = 32> { /// The length of the motif encoded in this weight matrix.
pub length: usize, pub fn len(&self) -> usize {
pub data: DenseMatrix<f32, C>, self.data.rows()
}
impl<const C: usize> StripedScores<C> {
pub fn to_vec(&self) -> Vec<f32> {
let mut vec = Vec::with_capacity(self.length);
for i in 0..self.length {
let col = i / self.data.rows();
let row = i % self.data.rows();
vec.push(self.data[row][col]);
}
vec
} }
} }
use super::abc::Alphabet; use super::abc::Alphabet;
use super::abc::InvalidSymbol; use super::abc::InvalidSymbol;
use super::pwm::WeightMatrix;
use super::dense::DenseMatrix; use super::dense::DenseMatrix;
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
...@@ -61,6 +62,12 @@ pub struct StripedSequence<A: Alphabet, const C: usize = 32> { ...@@ -61,6 +62,12 @@ pub struct StripedSequence<A: Alphabet, const C: usize = 32> {
} }
impl<A: Alphabet, const C: usize> StripedSequence<A, C> { impl<A: Alphabet, const C: usize> StripedSequence<A, C> {
/// Prepare this striped sequence for a search with the given motif.
///
/// Forwards the motif length to `configure_wrap`.
pub fn configure<const K: usize>(&mut self, motif: &WeightMatrix<A, K>) {
    let motif_len = motif.len();
    self.configure_wrap(motif_len);
}
/// Add wrap-around rows for a motif of length `m`. /// Add wrap-around rows for a motif of length `m`.
pub fn configure_wrap(&mut self, m: usize) { pub fn configure_wrap(&mut self, m: usize) {
if m > self.wrap { if m > self.wrap {
......
...@@ -15,16 +15,19 @@ const PATTERNS: &[&'static str] = &["GTTGACCTTATCAAC", "GTTGATCCAGTCAAC"]; ...@@ -15,16 +15,19 @@ const PATTERNS: &[&'static str] = &["GTTGACCTTATCAAC", "GTTGATCCAGTCAAC"];
// scores computed with Bio.motifs // scores computed with Bio.motifs
#[rustfmt::skip] #[rustfmt::skip]
const EXPECTED: &[f32] = &[ const EXPECTED: &[f32] = &[
-23.07094 , -18.678621 , -15.219191 , -17.745737 , -18.678621 , -23.07094 , -18.678621 , -15.219191 , -17.745737 ,
-23.07094 , -17.745737 , -19.611507 , -27.463257 , -29.989803 , -18.678621 , -23.07094 , -17.745737 , -19.611507 ,
-14.286304 , -26.53037 , -15.219191 , -10.826873 , -10.826873 , -27.463257 , -29.989803 , -14.286304 , -26.53037 ,
-22.138054 , -38.774437 , -30.922688 , -5.50167 , -24.003826 , -15.219191 , -10.826873 , -10.826873 , -22.138054 ,
-18.678621 , -15.219191 , -35.315006 , -17.745737 , -10.826873 , -38.774437 , -30.922688 , -5.50167 , -24.003826 ,
-30.922688 , -23.07094 , -6.4345555, -31.855574 , -23.07094 , -18.678621 , -15.219191 , -35.315006 , -17.745737 ,
-15.219191 , -31.855574 , -8.961102 , -26.53037 , -27.463257 , -10.826873 , -30.922688 , -23.07094 , -6.4345555,
-14.286304 , -15.219191 , -26.53037 , -23.07094 , -18.678621 , -31.855574 , -23.07094 , -15.219191 , -31.855574 ,
-14.286304 , -18.678621 , -26.53037 , -16.152077 , -17.745737 , -8.961102 , -26.53037 , -27.463257 , -14.286304 ,
-18.678621 , -17.745737 , -14.286304 , -30.922688 , -18.678621 -15.219191 , -26.53037 , -23.07094 , -18.678621 ,
-14.286304 , -18.678621 , -26.53037 , -16.152077 ,
-17.745737 , -18.678621 , -17.745737 , -14.286304 ,
-30.922688 , -18.678621
]; ];
#[test] #[test]
...@@ -62,9 +65,9 @@ fn test_score_generic() { ...@@ -62,9 +65,9 @@ fn test_score_generic() {
#[cfg(target_feature = "avx2")] #[cfg(target_feature = "avx2")]
#[test] #[test]
fn test_score_avx2() { fn test_score_avx2() {
let seq = "ATGTCCCAACAACGATACCCCGAGCCCATCGCCGTCATCGGCTCGGCATGCAGATTCCCAGGCGCATCCAGCTCGCCCTCCAAGCTGTGGAGTCTGCTCCAGGAACCTCGCGACGTCCTCAAGAAGTTCGACCCAGACCGCCTCAACCTGAAACGATTCCATCATACCAACGGTGACACTCACGGTGCGACCGACGTCAACAACAAATCATATCTCCTCGAAGAAAACACCCGACTCTTCGATGCCTCGTTCTTCGGAATCAGCCCCCTGGAGGCGGCCGGTATGGACCCCCAGCAGCGTCTGTTGCTGGAAACCGTCTACGAGTCGTTTGAGGCGGCTGGCGTGACCCTCGATCAGCTCAAGGGTTCTTTGACCTCGGTTCATGTTGGCGTCATGACCAACGACTACTCCTTTATCCAGCTCCGTGACCCAGAAACGCTGTCGAAGTACAACGCGACTGGCACGGCCAACAGCATCATGTCGAACCGTATTTCATATGTCTTTGACTTGAAAGGTCCATCAGAGACCATCGACACGGCGTGCTCCAGCTCGCTGGTCGCCCTGCACCACGCTGCTCAGGGCCTGCTCAGCGGCGACTGCGAGACTGCCGTCGTCGCCGGCGTCAACCTCATCTTCGACCCCTCTCCATACATCACAGAGTCCAAGCTACACATGCTGTCACCCGACTCCCAGTCTCGCATGTGGGACAAGTCTGCAAATGGCTACGCCCGCGGCGAGGGCGCTGCCGCGCTGCTCCTGAAGCCCCTCAGCCGCGCCCTGAGGGACGGCGATCACATCGAGGGCATTGTCCGAGGCACAGGAGTCAACTCGGACGGCCAGAGCTCCGGCATCACCATGCCTTTTGCCCCTGCGCAGTCGGCGCTCATTCGCCAAACTTATCTCCGTGCTGGCCTCGACCCGATCAAGGACCGGCCTCAGTACTTCGAGTGCCACGGCACCGGAACTCCAGCTGGTGACCCCGTGGAAGCGCGAGCCATCAGCGAGTCGTTGTTGGACGGTGAAAATGTCCCAACAACGATACCCCGAGCCCATCGCCGTCATCGGCTCGGCATGCAGATTCCCAGGCGCATCCAGCTCGCCCTCCAAGCTGTGGAGTCTGCTCCAGGAACCTCGCGACGTCCTCAAGAAGTTCGACCCAGACCGCCTCAACCTGAAACGATTCCATCATACCAACGGTGACACTCACGGTGCGACCGACGTCAACAACAAATCATATCTCCTCGAAGAAAACACCCGACTCTTCGATGCCTCGTTCTTCGGAATCAGCCCCCTGGAGGCGGCCGGTATGGACCCCCAGCAGCGTCTGTTGCTGGAAACCGTCTACGAGTCGTTTGAGGCGGCTGGCGTGACCCTCGATCAGCTCAAGGGTTCTTTGACCTCGGTTCATGTTGGCGTCATGACCAACGACTACTCCTTTATCCAGCTCCGTGACCCAGAAACGCTGTCGAAGTACAACGCGACTGGCACGGCCAACAGCATCATGTCGAACCGTATTTCATATGTCTTTGACTTGAAAGGTCCATCAGAGACCATCGACACGGCGTGCTCCAGCTCGCTGGTCGCCCTGCACCACGCTGCTCAGGGCCTGCTCAGCGGCGACTGCGAGACTGCCGTCGTCGCCGGCGTCAACCTCATCTTCGACCCCTCTCCATACATCACAGAGTCCAAGCTACACATGCTGTCACCCGACTCCCAGTCTCGCATGTGGGACAAGTCTGCAAATGGCTACGCCCGCGGCGAGGGCGCTGCCGCGCTGCTCCTGAAGCCCCTCAGCCGCGCCCTGAGGGACGGCGATCACATCGAGGGCATTGTCCGAGGCACAGGAGTCAACTCGGACGGCCAGAGCTCCGGCATCACCATGCCTTTTGCCCCTGCGCAGTCGGCGCTCATTCGCCAAACTTATCTCCGTGCTGGCCTCGACCCGATCAAGGACCGGCCTCAGTACTTCGAGTGCCACGGCACCGGAAC
TCCAGCTGGTGACCCCGTGGAAGCGCGAGCCATCAGCGAGTCGTTGTTGGACGGTGAAA"; // let seq = "ATGTCCCAACAACGATACCCCGAGCCCATCGCCGTCATCGGCTCGGCATGCAGATTCCCAGGCGCATCCAGCTCGCCCTCCAAGCTGTGGAGTCTGCTCCAGGAACCTCGCGACGTCCTCAAGAAGTTCGACCCAGACCGCCTCAACCTGAAACGATTCCATCATACCAACGGTGACACTCACGGTGCGACCGACGTCAACAACAAATCATATCTCCTCGAAGAAAACACCCGACTCTTCGATGCCTCGTTCTTCGGAATCAGCCCCCTGGAGGCGGCCGGTATGGACCCCCAGCAGCGTCTGTTGCTGGAAACCGTCTACGAGTCGTTTGAGGCGGCTGGCGTGACCCTCGATCAGCTCAAGGGTTCTTTGACCTCGGTTCATGTTGGCGTCATGACCAACGACTACTCCTTTATCCAGCTCCGTGACCCAGAAACGCTGTCGAAGTACAACGCGACTGGCACGGCCAACAGCATCATGTCGAACCGTATTTCATATGTCTTTGACTTGAAAGGTCCATCAGAGACCATCGACACGGCGTGCTCCAGCTCGCTGGTCGCCCTGCACCACGCTGCTCAGGGCCTGCTCAGCGGCGACTGCGAGACTGCCGTCGTCGCCGGCGTCAACCTCATCTTCGACCCCTCTCCATACATCACAGAGTCCAAGCTACACATGCTGTCACCCGACTCCCAGTCTCGCATGTGGGACAAGTCTGCAAATGGCTACGCCCGCGGCGAGGGCGCTGCCGCGCTGCTCCTGAAGCCCCTCAGCCGCGCCCTGAGGGACGGCGATCACATCGAGGGCATTGTCCGAGGCACAGGAGTCAACTCGGACGGCCAGAGCTCCGGCATCACCATGCCTTTTGCCCCTGCGCAGTCGGCGCTCATTCGCCAAACTTATCTCCGTGCTGGCCTCGACCCGATCAAGGACCGGCCTCAGTACTTCGAGTGCCACGGCACCGGAACTCCAGCTGGTGACCCCGTGGAAGCGCGAGCCATCAGCGAGTCGTTGTTGGACGGTGAAAATGTCCCAACAACGATACCCCGAGCCCATCGCCGTCATCGGCTCGGCATGCAGATTCCCAGGCGCATCCAGCTCGCCCTCCAAGCTGTGGAGTCTGCTCCAGGAACCTCGCGACGTCCTCAAGAAGTTCGACCCAGACCGCCTCAACCTGAAACGATTCCATCATACCAACGGTGACACTCACGGTGCGACCGACGTCAACAACAAATCATATCTCCTCGAAGAAAACACCCGACTCTTCGATGCCTCGTTCTTCGGAATCAGCCCCCTGGAGGCGGCCGGTATGGACCCCCAGCAGCGTCTGTTGCTGGAAACCGTCTACGAGTCGTTTGAGGCGGCTGGCGTGACCCTCGATCAGCTCAAGGGTTCTTTGACCTCGGTTCATGTTGGCGTCATGACCAACGACTACTCCTTTATCCAGCTCCGTGACCCAGAAACGCTGTCGAAGTACAACGCGACTGGCACGGCCAACAGCATCATGTCGAACCGTATTTCATATGTCTTTGACTTGAAAGGTCCATCAGAGACCATCGACACGGCGTGCTCCAGCTCGCTGGTCGCCCTGCACCACGCTGCTCAGGGCCTGCTCAGCGGCGACTGCGAGACTGCCGTCGTCGCCGGCGTCAACCTCATCTTCGACCCCTCTCCATACATCACAGAGTCCAAGCTACACATGCTGTCACCCGACTCCCAGTCTCGCATGTGGGACAAGTCTGCAAATGGCTACGCCCGCGGCGAGGGCGCTGCCGCGCTGCTCCTGAAGCCCCTCAGCCGCGCCCTGAGGGACGGCGATCACATCGAGGGCATTGTCCGAGGCACAGGAGTCAACTCGGACGGCCAGAGCTCCGGCATCACCATGCCTTTTGCCCCTGCGCAGTCGGCGCTCATTCGCCAAACTTAT
CTCCGTGCTGGCCTCGACCCGATCAAGGACCGGCCTCAGTACTTCGAGTGCCACGGCACCGGAACTCCAGCTGGTGACCCCGTGGAAGCGCGAGCCATCAGCGAGTCGTTGTTGGACGGTGAAA";
let encoded = EncodedSequence::<DnaAlphabet>::from_text(&seq[..]).unwrap(); let encoded = EncodedSequence::<DnaAlphabet>::from_text(SEQUENCE).unwrap();
let striped = encoded.to_striped::<{ std::mem::size_of::<__m256>() }>(); let mut striped = encoded.to_striped::<{ std::mem::size_of::<__m256>() }>();
let cm = CountMatrix::<DnaAlphabet, { DnaAlphabet::K }>::from_sequences( let cm = CountMatrix::<DnaAlphabet, { DnaAlphabet::K }>::from_sequences(
"MX000001", "MX000001",
...@@ -76,13 +79,12 @@ fn test_score_avx2() { ...@@ -76,13 +79,12 @@ fn test_score_avx2() {
let pbm = cm.to_probability([0.1, 0.1, 0.1, 0.1, 0.0]); let pbm = cm.to_probability([0.1, 0.1, 0.1, 0.1, 0.0]);
let pwm = pbm.to_weight([0.25, 0.25, 0.25, 0.25, 0.0]); let pwm = pbm.to_weight([0.25, 0.25, 0.25, 0.25, 0.0]);
striped.configure_wrap(pwm.data.rows() - 1);
let pli = Pipeline::<_, __m256>::new(); let pli = Pipeline::<_, __m256>::new();
let result = pli.score(&striped, &pwm); let result = pli.score(&striped, &pwm);
let scores = result.to_vec(); let scores = result.to_vec();
assert_eq!(scores.len(), EXPECTED.len());
// assert_eq!(scores.len(), seq.len() - cm.len() + 1);
// assert_eq!(scores[0], -23.07094); // -23.07094
for i in 0..EXPECTED.len() { for i in 0..EXPECTED.len() {
assert!( assert!(
(scores[i] - EXPECTED[i]).abs() < 1e-5, (scores[i] - EXPECTED[i]).abs() < 1e-5,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment