Skip to content
Snippets Groups Projects
Commit 54b18e19 authored by Martin Larralde's avatar Martin Larralde
Browse files

Use raw pointer arithmetic in `__m256` implementation of the pipeline

parent 61cba1d2
No related branches found
No related tags found
No related merge requests found
use std::str::FromStr;
use nom::IResult;
use nom::error::Error;
use nom::error::ErrorKind;
use nom::multi::count;
use nom::multi::many_till;
use nom::combinator::eof;
use nom::combinator::map_res;
use nom::sequence::delimited;
use nom::bytes::complete::is_a;
use nom::bytes::complete::tag;
use nom::bytes::complete::take_till;
use nom::bytes::complete::take_while;
use nom::bytes::complete::take_while1;
use nom::combinator::eof;
use nom::combinator::map_res;
use nom::error::Error;
use nom::error::ErrorKind;
use nom::multi::count;
use nom::multi::many_till;
use nom::sequence::delimited;
use nom::IResult;
use crate::abc::Symbol;
use crate::abc::Alphabet;
use crate::pwm::CountMatrix;
use crate::abc::Symbol;
use crate::dense::DenseMatrix;
use crate::pwm::CountMatrix;
fn is_newline(c: char) -> bool {
c == '\r' || c == '\n'
......@@ -56,7 +55,7 @@ fn parse_symbol<S: Symbol>(input: &str) -> IResult<&str, S> {
if let Some(c) = input.chars().nth(0) {
match S::try_from(c) {
Ok(s) => Ok((&input[1..], s)),
Err(_) => Err(nom::Err::Failure(Error::new(input, ErrorKind::MapRes)))
Err(_) => Err(nom::Err::Failure(Error::new(input, ErrorKind::MapRes))),
}
} else {
Err(nom::Err::Error(Error::new(input, ErrorKind::Eof)))
......@@ -66,9 +65,9 @@ fn parse_symbol<S: Symbol>(input: &str) -> IResult<&str, S> {
fn parse_alphabet<S: Symbol>(input: &str) -> IResult<&str, Vec<S>> {
let (input, _) = tag("P0")(input)?;
let (input, _) = take_while1(is_space)(input)?;
let (input, ( symbols, _ )) = many_till(
let (input, (symbols, _)) = many_till(
delimited(take_while(is_space), parse_symbol, take_while(is_space)),
is_a("\n\r")
is_a("\n\r"),
)(input)?;
let (input, _) = take_while(is_newline)(input)?;
Ok((input, symbols))
......@@ -78,7 +77,11 @@ fn parse_row(input: &str, k: usize) -> IResult<&str, Vec<u32>> {
let (input, _) = take_while1(char::is_numeric)(input)?;
let (input, _) = take_while1(char::is_whitespace)(input)?;
let (input, counts) = count(
delimited(take_while(is_space), parse_integer::<u32>, take_while(is_space)),
delimited(
take_while(is_space),
parse_integer::<u32>,
take_while(is_space),
),
k,
)(input)?;
let (input, _) = take_till(is_newline)(input)?;
......@@ -96,7 +99,6 @@ pub fn parse_matrix<A: Alphabet, const K: usize>(input: &str) -> IResult<&str, C
let (input, _) = tag("//")(input)?;
let (input, _) = take_while1(char::is_whitespace)(input)?;
let mut data = DenseMatrix::<u32, K>::new(counts.len());
for (i, count) in counts.iter().enumerate() {
for (s, &c) in symbols.iter().zip(count.iter()) {
......@@ -108,18 +110,19 @@ pub fn parse_matrix<A: Alphabet, const K: usize>(input: &str) -> IResult<&str, C
Ok((input, matrix))
}
pub fn parse_matrices<A: Alphabet, const K: usize>(input: &str) -> IResult<&str, Vec<CountMatrix<A, K>>> {
pub fn parse_matrices<A: Alphabet, const K: usize>(
input: &str,
) -> IResult<&str, Vec<CountMatrix<A, K>>> {
let (input, (matrices, _)) = many_till(parse_matrix, eof)(input)?;
Ok((input, matrices))
}
#[cfg(test)]
mod test {
use crate::abc::DnaSymbol;
use crate::abc::DnaAlphabet;
use crate::abc::Alphabet;
use crate::abc::DnaAlphabet;
use crate::abc::DnaSymbol;
use crate::abc::Symbol;
#[test]
......@@ -143,7 +146,10 @@ mod test {
let line = "P0 A T G C\n";
let res = super::parse_alphabet::<DnaSymbol>(line).unwrap();
assert_eq!(res.0, "");
assert_eq!(res.1, vec![DnaSymbol::A, DnaSymbol::T, DnaSymbol::G, DnaSymbol::C]);
assert_eq!(
res.1,
vec![DnaSymbol::A, DnaSymbol::T, DnaSymbol::G, DnaSymbol::C]
);
let line = "P0 A T\n";
let res = super::parse_alphabet::<DnaSymbol>(line).unwrap();
......@@ -179,9 +185,9 @@ mod test {
"XX\n",
"//\n",
);
let res = super::parse_matrix::<DnaAlphabet, {DnaAlphabet::K}>(text).unwrap();
let res = super::parse_matrix::<DnaAlphabet, { DnaAlphabet::K }>(text).unwrap();
assert_eq!(res.0, "");
let matrix = res.1;
assert_eq!(matrix.name, "prodoric_MX000001");
assert_eq!(matrix.data.rows(), 7);
......@@ -196,4 +202,4 @@ mod test {
assert_eq!(matrix.data[5][DnaSymbol::C.as_index()], 1);
assert_eq!(matrix.data[5][DnaSymbol::N.as_index()], 0);
}
}
\ No newline at end of file
}
......@@ -60,11 +60,19 @@ impl Pipeline<DnaAlphabet, f32> {
impl Pipeline<DnaAlphabet, __m256> {
pub fn score(
&self,
seq: &StripedSequence<DnaAlphabet, 32>,
seq: &StripedSequence<DnaAlphabet, { std::mem::size_of::<__m256i>() }>,
pwm: &WeightMatrix<DnaAlphabet, { DnaAlphabet::K }>,
) -> DenseMatrix<f32, 32> {
const S: i32 = std::mem::size_of::<f32>() as i32;
const C: usize = std::mem::size_of::<__m256i>();
const K: usize = DnaAlphabet::K;
let mut result = DenseMatrix::new(seq.data.rows());
unsafe {
// get raw pointers to data
let sdata = seq.data[0].as_ptr();
let mdata = pwm.data[0].as_ptr();
let rdata: *mut f32 = result[0].as_mut_ptr();
// mask vectors for broadcasting:
let m1: __m256i = _mm256_set_epi32(
0xFFFFFF03u32 as i32,
......@@ -113,49 +121,14 @@ impl Pipeline<DnaAlphabet, __m256> {
let mut s3 = _mm256_setzero_ps();
let mut s4 = _mm256_setzero_ps();
// for j in 0..pwm.data.rows() {
// // load table
// let row = pwm.data[j].as_ptr();
// let c = _mm_load_ps(row);
// let t = _mm256_set_m128(c, c);
// // load text
// let x = _mm256_loadu_si256(seq.data[i+j].as_ptr() as *const __m256i);
// // compute probabilities using an external lookup table
// let p1 = _mm256_permutevar_ps(t, _mm256_shuffle_epi8(x, m1));
// let p2 = _mm256_permutevar_ps(t, _mm256_shuffle_epi8(x, m2));
// let p3 = _mm256_permutevar_ps(t, _mm256_shuffle_epi8(x, m3));
// let p4 = _mm256_permutevar_ps(t, _mm256_shuffle_epi8(x, m4));
// // add log odds
// s1 = _mm256_add_ps(s1, p1);
// s2 = _mm256_add_ps(s2, p2);
// s3 = _mm256_add_ps(s3, p3);
// s4 = _mm256_add_ps(s4, p4);
// }
for j in 0..pwm.data.rows() {
let x = _mm256_loadu_si256(seq.data[i + j].as_ptr() as *const __m256i);
let row = pwm.data[j].as_ptr();
let x = _mm256_loadu_si256(sdata.add((i+j)*C) as *const __m256i);
let row = mdata.add(j*K);
// compute probabilities using an external lookup table
let p1 = _mm256_i32gather_ps(
row,
_mm256_shuffle_epi8(x, m1),
std::mem::size_of::<f32>() as _,
);
let p2 = _mm256_i32gather_ps(
row,
_mm256_shuffle_epi8(x, m2),
std::mem::size_of::<f32>() as _,
);
let p3 = _mm256_i32gather_ps(
row,
_mm256_shuffle_epi8(x, m3),
std::mem::size_of::<f32>() as _,
);
let p4 = _mm256_i32gather_ps(
row,
_mm256_shuffle_epi8(x, m4),
std::mem::size_of::<f32>() as _,
);
let p1 = _mm256_i32gather_ps(row, _mm256_shuffle_epi8(x, m1), S);
let p2 = _mm256_i32gather_ps(row, _mm256_shuffle_epi8(x, m2), S);
let p3 = _mm256_i32gather_ps(row, _mm256_shuffle_epi8(x, m3), S);
let p4 = _mm256_i32gather_ps(row, _mm256_shuffle_epi8(x, m4), S);
// add log odds
s1 = _mm256_add_ps(s1, p1);
s2 = _mm256_add_ps(s2, p2);
......@@ -163,11 +136,11 @@ impl Pipeline<DnaAlphabet, __m256> {
s4 = _mm256_add_ps(s4, p4);
}
let row: &mut [f32] = &mut result[i];
_mm256_storeu_ps(&mut row[0x0], s1);
_mm256_storeu_ps(&mut row[0x4], s1);
_mm256_storeu_ps(&mut row[0x8], s1);
_mm256_storeu_ps(&mut row[0xc], s1);
let row = rdata.add(i);
_mm256_storeu_ps(row, s1);
_mm256_storeu_ps(row.add(0x4), s2);
_mm256_storeu_ps(row.add(0x8), s3);
_mm256_storeu_ps(row.add(0xc), s4);
}
}
result
......
......@@ -32,11 +32,11 @@ impl<A: Alphabet, const K: usize> From<f32> for Pseudocount<A, K> {
pub struct CountMatrix<A: Alphabet, const K: usize> {
pub alphabet: A,
pub data: DenseMatrix<u32, K>,
pub name: String, // FIXME: Use `Rc` instead to avoid copies.
pub name: String, // FIXME: Use `Rc` instead to avoid copies.
}
impl<A: Alphabet, const K: usize> CountMatrix<A, K> {
pub fn new<S>(name: S, data: DenseMatrix<u32, K>) -> Result<Self, ()>
pub fn new<S>(name: S, data: DenseMatrix<u32, K>) -> Result<Self, ()>
where
S: Into<String>,
{
......@@ -69,7 +69,7 @@ impl<A: Alphabet, const K: usize> CountMatrix<A, K> {
Ok(Self {
alphabet: A::default(),
data: data.unwrap_or_else(|| DenseMatrix::new(0)),
name: name.into()
name: name.into(),
})
}
......
0% Loading Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment