Skip to content
Snippets Groups Projects
Commit c6811bbd authored by Martin Larralde
Browse files

Fix recovery of aligned scores in AVX2 implementation of the pipeline

parent 41113b90
No related branches found
No related tags found
No related merge requests found
......@@ -16,10 +16,10 @@ pub use abc::DnaSymbol;
pub use abc::Symbol;
pub use dense::DenseMatrix;
pub use pli::Pipeline;
pub use pli::StripedScores;
pub use pwm::Background;
pub use pwm::CountMatrix;
pub use pwm::ProbabilityMatrix;
pub use pwm::StripedScores;
pub use pwm::WeightMatrix;
pub use seq::EncodedSequence;
pub use seq::StripedSequence;
......@@ -6,7 +6,6 @@ use super::abc::Alphabet;
use super::abc::DnaAlphabet;
use super::abc::Symbol;
use super::dense::DenseMatrix;
use super::pwm::StripedScores;
use super::pwm::WeightMatrix;
use super::seq::EncodedSequence;
use super::seq::StripedSequence;
......@@ -39,11 +38,11 @@ impl Pipeline<DnaAlphabet, f32> {
&self,
seq: &StripedSequence<DnaAlphabet, C>,
pwm: &WeightMatrix<DnaAlphabet, { DnaAlphabet::K }>,
) -> StripedScores<C> {
) -> StripedScores<f32, C> {
let mut result = DenseMatrix::<f32, C>::new(seq.data.rows());
for i in 0..seq.length - pwm.data.rows() + 1 {
for i in 0..seq.length - pwm.len() + 1 {
let mut score = 0.0;
for j in 0..pwm.data.rows() {
for j in 0..pwm.len() {
let offset = i + j;
let col = offset / seq.data.rows();
let row = offset % seq.data.rows();
......@@ -54,8 +53,9 @@ impl Pipeline<DnaAlphabet, f32> {
result[row][col] = score;
}
StripedScores {
length: seq.length - pwm.data.rows() + 1,
length: seq.length - pwm.len() + 1,
data: result,
marker: std::marker::PhantomData,
}
}
}
......@@ -66,16 +66,18 @@ impl Pipeline<DnaAlphabet, __m256> {
&self,
seq: &StripedSequence<DnaAlphabet, { std::mem::size_of::<__m256i>() }>,
pwm: &WeightMatrix<DnaAlphabet, { DnaAlphabet::K }>,
) -> StripedScores<{ std::mem::size_of::<__m256i>() }> {
) -> StripedScores<__m256, { std::mem::size_of::<__m256i>() }> {
const S: i32 = std::mem::size_of::<f32>() as i32;
const C: usize = std::mem::size_of::<__m256i>();
const K: usize = DnaAlphabet::K;
let mut result = DenseMatrix::new((seq.length + C) / C);
if (seq.wrap < pwm.len() - 1) {
panic!("not enough wrapping rows for motif of length {}", pwm.len());
}
let mut result = DenseMatrix::new(seq.data.rows() - seq.wrap);
unsafe {
// get raw pointers to data
let sdata = seq.data[0].as_ptr();
let mdata = pwm.data[0].as_ptr();
// mask vectors for broadcasting:
let m1: __m256i = _mm256_set_epi32(
0xFFFFFF03u32 as i32,
......@@ -118,13 +120,13 @@ impl Pipeline<DnaAlphabet, __m256> {
0xFFFFFF0Cu32 as i32,
);
// loop over every row of the sequence data
for i in 0..seq.data.rows() - pwm.data.rows() + 1 {
for i in 0..seq.data.rows() - seq.wrap {
let mut s1 = _mm256_setzero_ps();
let mut s2 = _mm256_setzero_ps();
let mut s3 = _mm256_setzero_ps();
let mut s4 = _mm256_setzero_ps();
for j in 0..pwm.data.rows() {
let x = _mm256_load_si256(seq.data[i + j].as_ptr() as *const __m256i);
for j in 0..pwm.len() {
let x = _mm256_load_si256(seq.data[i+j].as_ptr() as *const __m256i);
let row = pwm.data[j].as_ptr();
// compute probabilities using an external lookup table
let p1 = _mm256_i32gather_ps(row, _mm256_shuffle_epi8(x, m1), S);
......@@ -137,19 +139,64 @@ impl Pipeline<DnaAlphabet, __m256> {
s3 = _mm256_add_ps(s3, p3);
s4 = _mm256_add_ps(s4, p4);
}
let row = &mut result[i];
let rowptr = row.as_mut_ptr();
_mm256_store_ps(rowptr, s1);
_mm256_store_ps(rowptr.add(0x08), s2);
_mm256_store_ps(rowptr.add(0x10), s3);
_mm256_store_ps(rowptr.add(0x18), s4);
_mm256_storeu_ps(row[0..].as_mut_ptr(), s1);
_mm256_storeu_ps(row[8..].as_mut_ptr(), s2);
_mm256_storeu_ps(row[16..].as_mut_ptr(), s3);
_mm256_storeu_ps(row[24..].as_mut_ptr(), s4);
}
}
StripedScores {
length: seq.length - pwm.data.rows() + 1,
length: seq.length - pwm.len() + 1,
data: result,
marker: std::marker::PhantomData,
}
}
}
/// Scores produced by a pipeline, kept in the striped matrix layout.
///
/// The type parameter `V` names the vector type of the pipeline that
/// produced the scores. It is only carried through `marker`, and it selects
/// which `to_vec` implementation is available to decode `data`. The const
/// parameter `C` is forwarded unchanged to the underlying `DenseMatrix`.
#[derive(Clone, Debug)]
pub struct StripedScores<V: Vector, const C: usize = 32> {
/// Number of scores returned by `to_vec`.
pub length: usize,
/// Raw `f32` score storage, indexed as `data[row][col]`.
pub data: DenseMatrix<f32, C>,
/// Zero-sized tag tying these scores to the vector type `V`.
marker: std::marker::PhantomData<V>,
}
impl<const C: usize> StripedScores<f32, C> {
    /// Flatten the striped scores into a plain vector of `length` values.
    ///
    /// The score at position `i` is read from row `i % rows` and column
    /// `i / rows` of `data`, where `rows` is `data.rows()`; i.e. the matrix
    /// is walked column by column.
    pub fn to_vec(&self) -> Vec<f32> {
        let rows = self.data.rows();
        (0..self.length)
            .map(|pos| {
                let col = pos / rows;
                let row = pos % rows;
                self.data[row][col]
            })
            .collect()
    }
}
#[cfg(target_feature = "avx2")]
impl<const C: usize> StripedScores<__m256, C> {
    /// Flatten the striped scores into a plain vector of `length` values.
    ///
    /// The matrix is walked column by column, like the scalar variant, but
    /// the logical column `k` is read from the stored column
    /// `LANE_ORDER[k]` instead of column `k` itself.
    pub fn to_vec(&self) -> Vec<f32> {
        // NOTE(@althonos): an AVX2 `__m256` behaves as two independent
        //                  `__m128` halves, so the shuffling done while
        //                  scoring leaves the columns interleaved; this
        //                  table maps each logical column back to the
        //                  stored column that holds it.
        #[rustfmt::skip]
        const LANE_ORDER: [usize; 32] = [
             0,  1,  2,  3,  8,  9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27,
             4,  5,  6,  7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31,
        ];

        let rows = self.data.rows();
        let mut scores = Vec::with_capacity(self.length);
        // Cursor over the matrix: `r` moves down a column, `c` is advanced
        // every time the bottom of a column has been passed.
        let (mut r, mut c) = (0usize, 0usize);
        while scores.len() < self.length {
            scores.push(self.data[r][LANE_ORDER[c]]);
            r += 1;
            if r >= rows {
                r = 0;
                c += 1;
            }
        }
        scores
    }
}
\ No newline at end of file
......@@ -166,20 +166,9 @@ pub struct WeightMatrix<A: Alphabet, const K: usize> {
pub name: String,
}
#[derive(Clone, Debug)]
pub struct StripedScores<const C: usize = 32> {
pub length: usize,
pub data: DenseMatrix<f32, C>,
}
impl<const C: usize> StripedScores<C> {
pub fn to_vec(&self) -> Vec<f32> {
let mut vec = Vec::with_capacity(self.length);
for i in 0..self.length {
let col = i / self.data.rows();
let row = i % self.data.rows();
vec.push(self.data[row][col]);
}
vec
impl<A: Alphabet, const K: usize> WeightMatrix<A, K> {
/// The length of the motif encoded in this weight matrix.
pub fn len(&self) -> usize {
self.data.rows()
}
}
use super::abc::Alphabet;
use super::abc::InvalidSymbol;
use super::pwm::WeightMatrix;
use super::dense::DenseMatrix;
#[derive(Clone, Debug)]
......@@ -61,6 +62,12 @@ pub struct StripedSequence<A: Alphabet, const C: usize = 32> {
}
impl<A: Alphabet, const C: usize> StripedSequence<A, C> {
/// Reconfigure the striped sequence for searching with a motif.
///
/// Forwards the motif length (`motif.len()`) to `configure_wrap`; no other
/// property of the weight matrix is read.
pub fn configure<const K: usize>(&mut self, motif: &WeightMatrix<A, K>) {
// Only the number of motif positions is needed to set up the wrap rows.
self.configure_wrap(motif.len());
}
/// Add wrap-around rows for a motif of length `m`.
pub fn configure_wrap(&mut self, m: usize) {
if m > self.wrap {
......
......@@ -15,16 +15,19 @@ const PATTERNS: &[&'static str] = &["GTTGACCTTATCAAC", "GTTGATCCAGTCAAC"];
// scores computed with Bio.motifs
#[rustfmt::skip]
const EXPECTED: &[f32] = &[
-23.07094 , -18.678621 , -15.219191 , -17.745737 , -18.678621 ,
-23.07094 , -17.745737 , -19.611507 , -27.463257 , -29.989803 ,
-14.286304 , -26.53037 , -15.219191 , -10.826873 , -10.826873 ,
-22.138054 , -38.774437 , -30.922688 , -5.50167 , -24.003826 ,
-18.678621 , -15.219191 , -35.315006 , -17.745737 , -10.826873 ,
-30.922688 , -23.07094 , -6.4345555, -31.855574 , -23.07094 ,
-15.219191 , -31.855574 , -8.961102 , -26.53037 , -27.463257 ,
-14.286304 , -15.219191 , -26.53037 , -23.07094 , -18.678621 ,
-14.286304 , -18.678621 , -26.53037 , -16.152077 , -17.745737 ,
-18.678621 , -17.745737 , -14.286304 , -30.922688 , -18.678621
-23.07094 , -18.678621 , -15.219191 , -17.745737 ,
-18.678621 , -23.07094 , -17.745737 , -19.611507 ,
-27.463257 , -29.989803 , -14.286304 , -26.53037 ,
-15.219191 , -10.826873 , -10.826873 , -22.138054 ,
-38.774437 , -30.922688 , -5.50167 , -24.003826 ,
-18.678621 , -15.219191 , -35.315006 , -17.745737 ,
-10.826873 , -30.922688 , -23.07094 , -6.4345555,
-31.855574 , -23.07094 , -15.219191 , -31.855574 ,
-8.961102 , -26.53037 , -27.463257 , -14.286304 ,
-15.219191 , -26.53037 , -23.07094 , -18.678621 ,
-14.286304 , -18.678621 , -26.53037 , -16.152077 ,
-17.745737 , -18.678621 , -17.745737 , -14.286304 ,
-30.922688 , -18.678621
];
#[test]
......@@ -62,9 +65,9 @@ fn test_score_generic() {
#[cfg(target_feature = "avx2")]
#[test]
fn test_score_avx2() {
let seq = "ATGTCCCAACAACGATACCCCGAGCCCATCGCCGTCATCGGCTCGGCATGCAGATTCCCAGGCGCATCCAGCTCGCCCTCCAAGCTGTGGAGTCTGCTCCAGGAACCTCGCGACGTCCTCAAGAAGTTCGACCCAGACCGCCTCAACCTGAAACGATTCCATCATACCAACGGTGACACTCACGGTGCGACCGACGTCAACAACAAATCATATCTCCTCGAAGAAAACACCCGACTCTTCGATGCCTCGTTCTTCGGAATCAGCCCCCTGGAGGCGGCCGGTATGGACCCCCAGCAGCGTCTGTTGCTGGAAACCGTCTACGAGTCGTTTGAGGCGGCTGGCGTGACCCTCGATCAGCTCAAGGGTTCTTTGACCTCGGTTCATGTTGGCGTCATGACCAACGACTACTCCTTTATCCAGCTCCGTGACCCAGAAACGCTGTCGAAGTACAACGCGACTGGCACGGCCAACAGCATCATGTCGAACCGTATTTCATATGTCTTTGACTTGAAAGGTCCATCAGAGACCATCGACACGGCGTGCTCCAGCTCGCTGGTCGCCCTGCACCACGCTGCTCAGGGCCTGCTCAGCGGCGACTGCGAGACTGCCGTCGTCGCCGGCGTCAACCTCATCTTCGACCCCTCTCCATACATCACAGAGTCCAAGCTACACATGCTGTCACCCGACTCCCAGTCTCGCATGTGGGACAAGTCTGCAAATGGCTACGCCCGCGGCGAGGGCGCTGCCGCGCTGCTCCTGAAGCCCCTCAGCCGCGCCCTGAGGGACGGCGATCACATCGAGGGCATTGTCCGAGGCACAGGAGTCAACTCGGACGGCCAGAGCTCCGGCATCACCATGCCTTTTGCCCCTGCGCAGTCGGCGCTCATTCGCCAAACTTATCTCCGTGCTGGCCTCGACCCGATCAAGGACCGGCCTCAGTACTTCGAGTGCCACGGCACCGGAACTCCAGCTGGTGACCCCGTGGAAGCGCGAGCCATCAGCGAGTCGTTGTTGGACGGTGAAAATGTCCCAACAACGATACCCCGAGCCCATCGCCGTCATCGGCTCGGCATGCAGATTCCCAGGCGCATCCAGCTCGCCCTCCAAGCTGTGGAGTCTGCTCCAGGAACCTCGCGACGTCCTCAAGAAGTTCGACCCAGACCGCCTCAACCTGAAACGATTCCATCATACCAACGGTGACACTCACGGTGCGACCGACGTCAACAACAAATCATATCTCCTCGAAGAAAACACCCGACTCTTCGATGCCTCGTTCTTCGGAATCAGCCCCCTGGAGGCGGCCGGTATGGACCCCCAGCAGCGTCTGTTGCTGGAAACCGTCTACGAGTCGTTTGAGGCGGCTGGCGTGACCCTCGATCAGCTCAAGGGTTCTTTGACCTCGGTTCATGTTGGCGTCATGACCAACGACTACTCCTTTATCCAGCTCCGTGACCCAGAAACGCTGTCGAAGTACAACGCGACTGGCACGGCCAACAGCATCATGTCGAACCGTATTTCATATGTCTTTGACTTGAAAGGTCCATCAGAGACCATCGACACGGCGTGCTCCAGCTCGCTGGTCGCCCTGCACCACGCTGCTCAGGGCCTGCTCAGCGGCGACTGCGAGACTGCCGTCGTCGCCGGCGTCAACCTCATCTTCGACCCCTCTCCATACATCACAGAGTCCAAGCTACACATGCTGTCACCCGACTCCCAGTCTCGCATGTGGGACAAGTCTGCAAATGGCTACGCCCGCGGCGAGGGCGCTGCCGCGCTGCTCCTGAAGCCCCTCAGCCGCGCCCTGAGGGACGGCGATCACATCGAGGGCATTGTCCGAGGCACAGGAGTCAACTCGGACGGCCAGAGCTCCGGCATCACCATGCCTTTTGCCCCTGCGCAGTCGGCGCTCATTCGCCAAACTTATCTCCGTGCTGGCCTCGACCCGATCAAGGACCGGCCTCAGTACTTCGAGTGCCACGGCACCGGAAC
TCCAGCTGGTGACCCCGTGGAAGCGCGAGCCATCAGCGAGTCGTTGTTGGACGGTGAAA";
let encoded = EncodedSequence::<DnaAlphabet>::from_text(&seq[..]).unwrap();
let striped = encoded.to_striped::<{ std::mem::size_of::<__m256>() }>();
// let seq = "ATGTCCCAACAACGATACCCCGAGCCCATCGCCGTCATCGGCTCGGCATGCAGATTCCCAGGCGCATCCAGCTCGCCCTCCAAGCTGTGGAGTCTGCTCCAGGAACCTCGCGACGTCCTCAAGAAGTTCGACCCAGACCGCCTCAACCTGAAACGATTCCATCATACCAACGGTGACACTCACGGTGCGACCGACGTCAACAACAAATCATATCTCCTCGAAGAAAACACCCGACTCTTCGATGCCTCGTTCTTCGGAATCAGCCCCCTGGAGGCGGCCGGTATGGACCCCCAGCAGCGTCTGTTGCTGGAAACCGTCTACGAGTCGTTTGAGGCGGCTGGCGTGACCCTCGATCAGCTCAAGGGTTCTTTGACCTCGGTTCATGTTGGCGTCATGACCAACGACTACTCCTTTATCCAGCTCCGTGACCCAGAAACGCTGTCGAAGTACAACGCGACTGGCACGGCCAACAGCATCATGTCGAACCGTATTTCATATGTCTTTGACTTGAAAGGTCCATCAGAGACCATCGACACGGCGTGCTCCAGCTCGCTGGTCGCCCTGCACCACGCTGCTCAGGGCCTGCTCAGCGGCGACTGCGAGACTGCCGTCGTCGCCGGCGTCAACCTCATCTTCGACCCCTCTCCATACATCACAGAGTCCAAGCTACACATGCTGTCACCCGACTCCCAGTCTCGCATGTGGGACAAGTCTGCAAATGGCTACGCCCGCGGCGAGGGCGCTGCCGCGCTGCTCCTGAAGCCCCTCAGCCGCGCCCTGAGGGACGGCGATCACATCGAGGGCATTGTCCGAGGCACAGGAGTCAACTCGGACGGCCAGAGCTCCGGCATCACCATGCCTTTTGCCCCTGCGCAGTCGGCGCTCATTCGCCAAACTTATCTCCGTGCTGGCCTCGACCCGATCAAGGACCGGCCTCAGTACTTCGAGTGCCACGGCACCGGAACTCCAGCTGGTGACCCCGTGGAAGCGCGAGCCATCAGCGAGTCGTTGTTGGACGGTGAAAATGTCCCAACAACGATACCCCGAGCCCATCGCCGTCATCGGCTCGGCATGCAGATTCCCAGGCGCATCCAGCTCGCCCTCCAAGCTGTGGAGTCTGCTCCAGGAACCTCGCGACGTCCTCAAGAAGTTCGACCCAGACCGCCTCAACCTGAAACGATTCCATCATACCAACGGTGACACTCACGGTGCGACCGACGTCAACAACAAATCATATCTCCTCGAAGAAAACACCCGACTCTTCGATGCCTCGTTCTTCGGAATCAGCCCCCTGGAGGCGGCCGGTATGGACCCCCAGCAGCGTCTGTTGCTGGAAACCGTCTACGAGTCGTTTGAGGCGGCTGGCGTGACCCTCGATCAGCTCAAGGGTTCTTTGACCTCGGTTCATGTTGGCGTCATGACCAACGACTACTCCTTTATCCAGCTCCGTGACCCAGAAACGCTGTCGAAGTACAACGCGACTGGCACGGCCAACAGCATCATGTCGAACCGTATTTCATATGTCTTTGACTTGAAAGGTCCATCAGAGACCATCGACACGGCGTGCTCCAGCTCGCTGGTCGCCCTGCACCACGCTGCTCAGGGCCTGCTCAGCGGCGACTGCGAGACTGCCGTCGTCGCCGGCGTCAACCTCATCTTCGACCCCTCTCCATACATCACAGAGTCCAAGCTACACATGCTGTCACCCGACTCCCAGTCTCGCATGTGGGACAAGTCTGCAAATGGCTACGCCCGCGGCGAGGGCGCTGCCGCGCTGCTCCTGAAGCCCCTCAGCCGCGCCCTGAGGGACGGCGATCACATCGAGGGCATTGTCCGAGGCACAGGAGTCAACTCGGACGGCCAGAGCTCCGGCATCACCATGCCTTTTGCCCCTGCGCAGTCGGCGCTCATTCGCCAAACTTATCTCCGTGCTGGCCTCGACCCGATCAAGGACCGGCCTCAGTACTTCGAGTGCCACGGCACCGG
AACTCCAGCTGGTGACCCCGTGGAAGCGCGAGCCATCAGCGAGTCGTTGTTGGACGGTGAAA";
let encoded = EncodedSequence::<DnaAlphabet>::from_text(SEQUENCE).unwrap();
let mut striped = encoded.to_striped::<{ std::mem::size_of::<__m256>() }>();
let cm = CountMatrix::<DnaAlphabet, { DnaAlphabet::K }>::from_sequences(
"MX000001",
......@@ -76,13 +79,12 @@ fn test_score_avx2() {
let pbm = cm.to_probability([0.1, 0.1, 0.1, 0.1, 0.0]);
let pwm = pbm.to_weight([0.25, 0.25, 0.25, 0.25, 0.0]);
striped.configure_wrap(pwm.data.rows() - 1);
let pli = Pipeline::<_, __m256>::new();
let result = pli.score(&striped, &pwm);
let scores = result.to_vec();
// assert_eq!(scores.len(), seq.len() - cm.len() + 1);
// assert_eq!(scores[0], -23.07094); // -23.07094
assert_eq!(scores.len(), EXPECTED.len());
for i in 0..EXPECTED.len() {
assert!(
(scores[i] - EXPECTED[i]).abs() < 1e-5,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment