diff --git a/src/lib.rs b/src/lib.rs index dd7d4708431bb167f4c986d2cfc682d5c4447d6b..09a9985b7db6ec61588b73dd50360ef91a3e00c1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -16,10 +16,10 @@ pub use abc::DnaSymbol; pub use abc::Symbol; pub use dense::DenseMatrix; pub use pli::Pipeline; +pub use pli::StripedScores; pub use pwm::Background; pub use pwm::CountMatrix; pub use pwm::ProbabilityMatrix; -pub use pwm::StripedScores; pub use pwm::WeightMatrix; pub use seq::EncodedSequence; pub use seq::StripedSequence; diff --git a/src/pli.rs b/src/pli.rs index 74f6804763e4461a8d3a5218a23506933e00b576..a0b64d3b859c25d070a6693c2577483232a457be 100644 --- a/src/pli.rs +++ b/src/pli.rs @@ -6,7 +6,6 @@ use super::abc::Alphabet; use super::abc::DnaAlphabet; use super::abc::Symbol; use super::dense::DenseMatrix; -use super::pwm::StripedScores; use super::pwm::WeightMatrix; use super::seq::EncodedSequence; use super::seq::StripedSequence; @@ -39,11 +38,11 @@ impl Pipeline<DnaAlphabet, f32> { &self, seq: &StripedSequence<DnaAlphabet, C>, pwm: &WeightMatrix<DnaAlphabet, { DnaAlphabet::K }>, - ) -> StripedScores<C> { + ) -> StripedScores<f32, C> { let mut result = DenseMatrix::<f32, C>::new(seq.data.rows()); - for i in 0..seq.length - pwm.data.rows() + 1 { + for i in 0..seq.length - pwm.len() + 1 { let mut score = 0.0; - for j in 0..pwm.data.rows() { + for j in 0..pwm.len() { let offset = i + j; let col = offset / seq.data.rows(); let row = offset % seq.data.rows(); @@ -54,8 +53,9 @@ impl Pipeline<DnaAlphabet, f32> { result[row][col] = score; } StripedScores { - length: seq.length - pwm.data.rows() + 1, + length: seq.length - pwm.len() + 1, data: result, + marker: std::marker::PhantomData, } } } @@ -66,16 +66,18 @@ impl Pipeline<DnaAlphabet, __m256> { &self, seq: &StripedSequence<DnaAlphabet, { std::mem::size_of::<__m256i>() }>, pwm: &WeightMatrix<DnaAlphabet, { DnaAlphabet::K }>, - ) -> StripedScores<{ std::mem::size_of::<__m256i>() }> { + ) -> StripedScores<__m256, { std::mem::size_of::<__m256i>() }> { const S: i32 = std::mem::size_of::<f32>() as i32; const C: usize = std::mem::size_of::<__m256i>(); const K: usize = DnaAlphabet::K; - let mut result = DenseMatrix::new((seq.length + C) / C); + if (seq.wrap < pwm.len() - 1) { + panic!("not enough wrapping rows for motif of length {}", pwm.len()); + } + + let mut result = DenseMatrix::new(seq.data.rows() - seq.wrap); unsafe { // get raw pointers to data - let sdata = seq.data[0].as_ptr(); - let mdata = pwm.data[0].as_ptr(); // mask vectors for broadcasting: let m1: __m256i = _mm256_set_epi32( 0xFFFFFF03u32 as i32, @@ -118,13 +120,13 @@ impl Pipeline<DnaAlphabet, __m256> { 0xFFFFFF0Cu32 as i32, ); // loop over every row of the sequence data - for i in 0..seq.data.rows() - pwm.data.rows() + 1 { + for i in 0..seq.data.rows() - seq.wrap { let mut s1 = _mm256_setzero_ps(); let mut s2 = _mm256_setzero_ps(); let mut s3 = _mm256_setzero_ps(); let mut s4 = _mm256_setzero_ps(); - for j in 0..pwm.data.rows() { - let x = _mm256_load_si256(seq.data[i + j].as_ptr() as *const __m256i); + for j in 0..pwm.len() { + let x = _mm256_load_si256(seq.data[i+j].as_ptr() as *const __m256i); let row = pwm.data[j].as_ptr(); // compute probabilities using an external lookup table let p1 = _mm256_i32gather_ps(row, _mm256_shuffle_epi8(x, m1), S); @@ -137,19 +139,64 @@ impl Pipeline<DnaAlphabet, __m256> { s3 = _mm256_add_ps(s3, p3); s4 = _mm256_add_ps(s4, p4); } - let row = &mut result[i]; - let rowptr = row.as_mut_ptr(); - _mm256_store_ps(rowptr, s1); - _mm256_store_ps(rowptr.add(0x08), s2); - _mm256_store_ps(rowptr.add(0x10), s3); - _mm256_store_ps(rowptr.add(0x18), s4); + _mm256_storeu_ps(row[0..].as_mut_ptr(), s1); + _mm256_storeu_ps(row[8..].as_mut_ptr(), s2); + _mm256_storeu_ps(row[16..].as_mut_ptr(), s3); + _mm256_storeu_ps(row[24..].as_mut_ptr(), s4); } } StripedScores { - length: seq.length - pwm.data.rows() + 1, + length: seq.length - pwm.len() + 1, data: result, + marker: std::marker::PhantomData, } } } + +#[derive(Clone, Debug)] +pub struct StripedScores<V: Vector, const C: usize = 32> { + pub length: usize, + pub data: DenseMatrix<f32, C>, + marker: std::marker::PhantomData<V>, +} + +impl<const C: usize> StripedScores<f32, C> { + pub fn to_vec(&self) -> Vec<f32> { + let mut vec = Vec::with_capacity(self.length); + for i in 0..self.length { + let col = i / self.data.rows(); + let row = i % self.data.rows(); + vec.push(self.data[row][col]); + } + vec + } +} + +#[cfg(target_feature = "avx2")] +impl<const C: usize> StripedScores<__m256, C> { + pub fn to_vec(&self) -> Vec<f32> { + // NOTE(@althonos): Because in AVX2 the __m256 vector is actually + // two independent __m128, the shuffling creates + // intrication in the results. + #[rustfmt::skip] + const COLS: &[usize] = &[ + 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, + 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31, + ]; + + let mut col = 0; + let mut row = 0; + let mut vec = Vec::with_capacity(self.length); + for i in 0..self.length { + vec.push(self.data[row][COLS[col]]); + row += 1; + if row >= self.data.rows() { + row = 0; + col += 1; + } + } + vec + } +} \ No newline at end of file diff --git a/src/pwm.rs b/src/pwm.rs index 1ffe66fcb85762886fdfbc4c16c7b3ea1b69934f..e6ed40de6475b356fd1c450019c4e64071230d76 100644 --- a/src/pwm.rs +++ b/src/pwm.rs @@ -166,20 +166,9 @@ pub struct WeightMatrix<A: Alphabet, const K: usize> { pub name: String, } -#[derive(Clone, Debug)] -pub struct StripedScores<const C: usize = 32> { - pub length: usize, - pub data: DenseMatrix<f32, C>, -} - -impl<const C: usize> StripedScores<C> { - pub fn to_vec(&self) -> Vec<f32> { - let mut vec = Vec::with_capacity(self.length); - for i in 0..self.length { - let col = i / self.data.rows(); - let row = i % self.data.rows(); - vec.push(self.data[row][col]); - } - vec +impl<A: Alphabet, const K: usize> WeightMatrix<A, K> { + /// The length of the motif encoded in this weight matrix. + pub fn len(&self) -> usize { + self.data.rows() } } diff --git a/src/seq.rs b/src/seq.rs index 256c5cfcd460184523f5610bf215713ad4e169be..5f09c0812e9cc06781a8a7a5b6800be45ac8f0b5 100644 --- a/src/seq.rs +++ b/src/seq.rs @@ -1,5 +1,6 @@ use super::abc::Alphabet; use super::abc::InvalidSymbol; +use super::pwm::WeightMatrix; use super::dense::DenseMatrix; #[derive(Clone, Debug)] @@ -61,6 +62,12 @@ pub struct StripedSequence<A: Alphabet, const C: usize = 32> { } impl<A: Alphabet, const C: usize> StripedSequence<A, C> { + + /// Reconfigure the striped sequence for searching with a motif. + pub fn configure<const K: usize>(&mut self, motif: &WeightMatrix<A, K>) { + self.configure_wrap(motif.len()); + } + /// Add wrap-around rows for a motif of length `m`. pub fn configure_wrap(&mut self, m: usize) { if m > self.wrap { diff --git a/tests/dna.rs b/tests/dna.rs index cd81acd0ac7f78eb19e146ecaaafc97172eeb145..badaf2c3494092c2748dbe9638a691cbe9c0608e 100644 --- a/tests/dna.rs +++ b/tests/dna.rs @@ -15,16 +15,19 @@ const PATTERNS: &[&'static str] = &["GTTGACCTTATCAAC", "GTTGATCCAGTCAAC"]; // scores computed with Bio.motifs #[rustfmt::skip] const EXPECTED: &[f32] = &[ - -23.07094 , -18.678621 , -15.219191 , -17.745737 , -18.678621 , - -23.07094 , -17.745737 , -19.611507 , -27.463257 , -29.989803 , - -14.286304 , -26.53037 , -15.219191 , -10.826873 , -10.826873 , - -22.138054 , -38.774437 , -30.922688 , -5.50167 , -24.003826 , - -18.678621 , -15.219191 , -35.315006 , -17.745737 , -10.826873 , - -30.922688 , -23.07094 , -6.4345555, -31.855574 , -23.07094 , - -15.219191 , -31.855574 , -8.961102 , -26.53037 , -27.463257 , - -14.286304 , -15.219191 , -26.53037 , -23.07094 , -18.678621 , - -14.286304 , -18.678621 , -26.53037 , -16.152077 , -17.745737 , - -18.678621 , -17.745737 , -14.286304 , -30.922688 , -18.678621 + -23.07094 , -18.678621 , -15.219191 , -17.745737 , + -18.678621 , -23.07094 , -17.745737 , -19.611507 , + -27.463257 , -29.989803 , -14.286304 , -26.53037 , + -15.219191 , -10.826873 , -10.826873 , -22.138054 , + -38.774437 , -30.922688 , -5.50167 , -24.003826 , + -18.678621 , -15.219191 , -35.315006 , -17.745737 , + -10.826873 , -30.922688 , -23.07094 , -6.4345555, + -31.855574 , -23.07094 , -15.219191 , -31.855574 , + -8.961102 , -26.53037 , -27.463257 , -14.286304 , + -15.219191 , -26.53037 , -23.07094 , -18.678621 , + -14.286304 , -18.678621 , -26.53037 , -16.152077 , + -17.745737 , -18.678621 , -17.745737 , -14.286304 , + -30.922688 , -18.678621 ]; #[test] @@ -62,9 +65,9 @@ fn test_score_generic() { #[cfg(target_feature = "avx2")] #[test] fn test_score_avx2() { - let seq = "ATGTCCCAACAACGATACCCCGAGCCCATCGCCGTCATCGGCTCGGCATGCAGATTCCCAGGCGCATCCAGCTCGCCCTCCAAGCTGTGGAGTCTGCTCCAGGAACCTCGCGACGTCCTCAAGAAGTTCGACCCAGACCGCCTCAACCTGAAACGATTCCATCATACCAACGGTGACACTCACGGTGCGACCGACGTCAACAACAAATCATATCTCCTCGAAGAAAACACCCGACTCTTCGATGCCTCGTTCTTCGGAATCAGCCCCCTGGAGGCGGCCGGTATGGACCCCCAGCAGCGTCTGTTGCTGGAAACCGTCTACGAGTCGTTTGAGGCGGCTGGCGTGACCCTCGATCAGCTCAAGGGTTCTTTGACCTCGGTTCATGTTGGCGTCATGACCAACGACTACTCCTTTATCCAGCTCCGTGACCCAGAAACGCTGTCGAAGTACAACGCGACTGGCACGGCCAACAGCATCATGTCGAACCGTATTTCATATGTCTTTGACTTGAAAGGTCCATCAGAGACCATCGACACGGCGTGCTCCAGCTCGCTGGTCGCCCTGCACCACGCTGCTCAGGGCCTGCTCAGCGGCGACTGCGAGACTGCCGTCGTCGCCGGCGTCAACCTCATCTTCGACCCCTCTCCATACATCACAGAGTCCAAGCTACACATGCTGTCACCCGACTCCCAGTCTCGCATGTGGGACAAGTCTGCAAATGGCTACGCCCGCGGCGAGGGCGCTGCCGCGCTGCTCCTGAAGCCCCTCAGCCGCGCCCTGAGGGACGGCGATCACATCGAGGGCATTGTCCGAGGCACAGGAGTCAACTCGGACGGCCAGAGCTCCGGCATCACCATGCCTTTTGCCCCTGCGCAGTCGGCGCTCATTCGCCAAACTTATCTCCGTGCTGGCCTCGACCCGATCAAGGACCGGCCTCAGTACTTCGAGTGCCACGGCACCGGAACTCCAGCTGGTGACCCCGTGGAAGCGCGAGCCATCAGCGAGTCGTTGTTGGACGGTGAAAATGTCCCAACAACGATACCCCGAGCCCATCGCCGTCATCGGCTCGGCATGCAGATTCCCAGGCGCATCCAGCTCGCCCTCCAAGCTGTGGAGTCTGCTCCAGGAACCTCGCGACGTCCTCAAGAAGTTCGACCCAGACCGCCTCAACCTGAAACGATTCCATCATACCAACGGTGACACTCACGGTGCGACCGACGTCAACAACAAATCATATCTCCTCGAAGAAAACACCCGACTCTTCGATGCCTCGTTCTTCGGAATCAGCCCCCTGGAGGCGGCCGGTATGGACCCCCAGCAGCGTCTGTTGCTGGAAACCGTCTACGAGTCGTTTGAGGCGGCTGGCGTGACCCTCGATCAGCTCAAGGGTTCTTTGACCTCGGTTCATGTTGGCGTCATGACCAACGACTACTCCTTTATCCAGCTCCGTGACCCAGAAACGCTGTCGAAGTACAACGCGACTGGCACGGCCAACAGCATCATGTCGAACCGTATTTCATATGTCTTTGACTTGAAAGGTCCATCAGAGACCATCGACACGGCGTGCTCCAGCTCGCTGGTCGCCCTGCACCACGCTGCTCAGGGCCTGCTCAGCGGCGACTGCGAGACTGCCGTCGTCGCCGGCGTCAACCTCATCTTCGACCCCTCTCCATACATCACAGAGTCCAAGCTACACATGCTGTCACCCGACTCCCAGTCTCGCATGTGGGACAAGTCTGCAAATGGCTACGCCCGCGGCGAGGGCGCTGCCGCGCTGCTCCTGAAGCCCCTCAGCCGCGCCCTGAGGGACGGCGATCACATCGAGGGCATTGTCCGAGGCACAGGAGTCAACTCGGACGGCCAGAGCTCCGGCATCACCATGCCTTTTGCCCCTGCGCAGTCGGCGCTCATTCGCCAAACTTATCTCCGTGCTGGCCTCGACCCGATCAAGGACCGGCCTCAGTACTTCGAGTGCCACGGCACCGGAACTCCAGCTGGTGACCCCGTGGAAGCGCGAGCCATCAGCGAGTCGTTGTTGGACGGTGAAA"; - let encoded = EncodedSequence::<DnaAlphabet>::from_text(&seq[..]).unwrap(); - let striped = encoded.to_striped::<{ std::mem::size_of::<__m256>() }>(); + // let seq = "ATGTCCCAACAACGATACCCCGAGCCCATCGCCGTCATCGGCTCGGCATGCAGATTCCCAGGCGCATCCAGCTCGCCCTCCAAGCTGTGGAGTCTGCTCCAGGAACCTCGCGACGTCCTCAAGAAGTTCGACCCAGACCGCCTCAACCTGAAACGATTCCATCATACCAACGGTGACACTCACGGTGCGACCGACGTCAACAACAAATCATATCTCCTCGAAGAAAACACCCGACTCTTCGATGCCTCGTTCTTCGGAATCAGCCCCCTGGAGGCGGCCGGTATGGACCCCCAGCAGCGTCTGTTGCTGGAAACCGTCTACGAGTCGTTTGAGGCGGCTGGCGTGACCCTCGATCAGCTCAAGGGTTCTTTGACCTCGGTTCATGTTGGCGTCATGACCAACGACTACTCCTTTATCCAGCTCCGTGACCCAGAAACGCTGTCGAAGTACAACGCGACTGGCACGGCCAACAGCATCATGTCGAACCGTATTTCATATGTCTTTGACTTGAAAGGTCCATCAGAGACCATCGACACGGCGTGCTCCAGCTCGCTGGTCGCCCTGCACCACGCTGCTCAGGGCCTGCTCAGCGGCGACTGCGAGACTGCCGTCGTCGCCGGCGTCAACCTCATCTTCGACCCCTCTCCATACATCACAGAGTCCAAGCTACACATGCTGTCACCCGACTCCCAGTCTCGCATGTGGGACAAGTCTGCAAATGGCTACGCCCGCGGCGAGGGCGCTGCCGCGCTGCTCCTGAAGCCCCTCAGCCGCGCCCTGAGGGACGGCGATCACATCGAGGGCATTGTCCGAGGCACAGGAGTCAACTCGGACGGCCAGAGCTCCGGCATCACCATGCCTTTTGCCCCTGCGCAGTCGGCGCTCATTCGCCAAACTTATCTCCGTGCTGGCCTCGACCCGATCAAGGACCGGCCTCAGTACTTCGAGTGCCACGGCACCGGAACTCCAGCTGGTGACCCCGTGGAAGCGCGAGCCATCAGCGAGTCGTTGTTGGACGGTGAAAATGTCCCAACAACGATACCCCGAGCCCATCGCCGTCATCGGCTCGGCATGCAGATTCCCAGGCGCATCCAGCTCGCCCTCCAAGCTGTGGAGTCTGCTCCAGGAACCTCGCGACGTCCTCAAGAAGTTCGACCCAGACCGCCTCAACCTGAAACGATTCCATCATACCAACGGTGACACTCACGGTGCGACCGACGTCAACAACAAATCATATCTCCTCGAAGAAAACACCCGACTCTTCGATGCCTCGTTCTTCGGAATCAGCCCCCTGGAGGCGGCCGGTATGGACCCCCAGCAGCGTCTGTTGCTGGAAACCGTCTACGAGTCGTTTGAGGCGGCTGGCGTGACCCTCGATCAGCTCAAGGGTTCTTTGACCTCGGTTCATGTTGGCGTCATGACCAACGACTACTCCTTTATCCAGCTCCGTGACCCAGAAACGCTGTCGAAGTACAACGCGACTGGCACGGCCAACAGCATCATGTCGAACCGTATTTCATATGTCTTTGACTTGAAAGGTCCATCAGAGACCATCGACACGGCGTGCTCCAGCTCGCTGGTCGCCCTGCACCACGCTGCTCAGGGCCTGCTCAGCGGCGACTGCGAGACTGCCGTCGTCGCCGGCGTCAACCTCATCTTCGACCCCTCTCCATACATCACAGAGTCCAAGCTACACATGCTGTCACCCGACTCCCAGTCTCGCATGTGGGACAAGTCTGCAAATGGCTACGCCCGCGGCGAGGGCGCTGCCGCGCTGCTCCTGAAGCCCCTCAGCCGCGCCCTGAGGGACGGCGATCACATCGAGGGCATTGTCCGAGGCACAGGAGTCAACTCGGACGGCCAGAGCTCCGGCATCACCATGCCTTTTGCCCCTGCGCAGTCGGCGCTCATTCGCCAAACTTATCTCCGTGCTGGCCTCGACCCGATCAAGGACCGGCCTCAGTACTTCGAGTGCCACGGCACCGGAACTCCAGCTGGTGACCCCGTGGAAGCGCGAGCCATCAGCGAGTCGTTGTTGGACGGTGAAA"; + let encoded = EncodedSequence::<DnaAlphabet>::from_text(SEQUENCE).unwrap(); + let mut striped = encoded.to_striped::<{ std::mem::size_of::<__m256>() }>(); let cm = CountMatrix::<DnaAlphabet, { DnaAlphabet::K }>::from_sequences( "MX000001", @@ -76,13 +79,12 @@ fn test_score_avx2() { let pbm = cm.to_probability([0.1, 0.1, 0.1, 0.1, 0.0]); let pwm = pbm.to_weight([0.25, 0.25, 0.25, 0.25, 0.0]); + striped.configure_wrap(pwm.data.rows() - 1); let pli = Pipeline::<_, __m256>::new(); let result = pli.score(&striped, &pwm); let scores = result.to_vec(); - - // assert_eq!(scores.len(), seq.len() - cm.len() + 1); - // assert_eq!(scores[0], -23.07094); // -23.07094 + assert_eq!(scores.len(), EXPECTED.len()); for i in 0..EXPECTED.len() { assert!( (scores[i] - EXPECTED[i]).abs() < 1e-5,