Skip to content
Snippets Groups Projects
Commit 18508e32 authored by Martin Larralde
Browse files

Rewrite `__m256` implementation using permute instead of gather instructions

parent e3461174
No related branches found
No related tags found
No related merge requests found
......@@ -4,6 +4,7 @@ use std::arch::x86_64::*;
use self::seal::Vector;
use super::abc::Alphabet;
use super::abc::DnaAlphabet;
use super::abc::DnaSymbol;
use super::abc::Symbol;
use super::dense::DenseMatrix;
use super::pwm::WeightMatrix;
......@@ -78,10 +79,6 @@ impl Pipeline<DnaAlphabet, __m256> {
pwm: &WeightMatrix<DnaAlphabet, { DnaAlphabet::K }>,
scores: &mut StripedScores<__m256, { std::mem::size_of::<__m256i>() }>,
) {
const S: i32 = std::mem::size_of::<f32>() as i32;
const C: usize = std::mem::size_of::<__m256i>();
const K: usize = DnaAlphabet::K;
let mut result = &mut scores.data;
scores.length = seq.length - pwm.len() + 1;
......@@ -93,9 +90,10 @@ impl Pipeline<DnaAlphabet, __m256> {
}
unsafe {
// get raw pointers to data
// mask vectors for broadcasting:
let m1: __m256i = _mm256_set_epi32(
// constant vector for comparing unknown bases
let n = _mm256_set1_epi8(DnaSymbol::N as i8);
// mask vectors for broadcasting uint8x32_t to uint32x8_t to floatx8_t
let m1 = _mm256_set_epi32(
0xFFFFFF03u32 as i32,
0xFFFFFF02u32 as i32,
0xFFFFFF01u32 as i32,
......@@ -105,7 +103,7 @@ impl Pipeline<DnaAlphabet, __m256> {
0xFFFFFF01u32 as i32,
0xFFFFFF00u32 as i32,
);
let m2: __m256i = _mm256_set_epi32(
let m2 = _mm256_set_epi32(
0xFFFFFF07u32 as i32,
0xFFFFFF06u32 as i32,
0xFFFFFF05u32 as i32,
......@@ -115,7 +113,7 @@ impl Pipeline<DnaAlphabet, __m256> {
0xFFFFFF05u32 as i32,
0xFFFFFF04u32 as i32,
);
let m3: __m256i = _mm256_set_epi32(
let m3 = _mm256_set_epi32(
0xFFFFFF0Bu32 as i32,
0xFFFFFF0Au32 as i32,
0xFFFFFF09u32 as i32,
......@@ -125,7 +123,7 @@ impl Pipeline<DnaAlphabet, __m256> {
0xFFFFFF09u32 as i32,
0xFFFFFF08u32 as i32,
);
let m4: __m256i = _mm256_set_epi32(
let m4 = _mm256_set_epi32(
0xFFFFFF0Fu32 as i32,
0xFFFFFF0Eu32 as i32,
0xFFFFFF0Du32 as i32,
......@@ -135,26 +133,49 @@ impl Pipeline<DnaAlphabet, __m256> {
0xFFFFFF0Du32 as i32,
0xFFFFFF0Cu32 as i32,
);
// loop over every row of the sequence data
// process every position of the sequence data
for i in 0..seq.data.rows() - seq.wrap {
// reset sums for current position
let mut s1 = _mm256_setzero_ps();
let mut s2 = _mm256_setzero_ps();
let mut s3 = _mm256_setzero_ps();
let mut s4 = _mm256_setzero_ps();
// advance position in the position weight matrix
for j in 0..pwm.len() {
// load sequence row and broadcast to f32
let x = _mm256_load_si256(seq.data[i + j].as_ptr() as *const __m256i);
let x1 = _mm256_shuffle_epi8(x, m1);
let x2 = _mm256_shuffle_epi8(x, m2);
let x3 = _mm256_shuffle_epi8(x, m3);
let x4 = _mm256_shuffle_epi8(x, m4);
// load row for current weight matrix position
let row = pwm.data[j].as_ptr();
// compute probabilities using an external lookup table
let p1 = _mm256_i32gather_ps(row, _mm256_shuffle_epi8(x, m1), S);
let p2 = _mm256_i32gather_ps(row, _mm256_shuffle_epi8(x, m2), S);
let p3 = _mm256_i32gather_ps(row, _mm256_shuffle_epi8(x, m3), S);
let p4 = _mm256_i32gather_ps(row, _mm256_shuffle_epi8(x, m4), S);
// add log odds
s1 = _mm256_add_ps(s1, p1);
s2 = _mm256_add_ps(s2, p2);
s3 = _mm256_add_ps(s3, p3);
s4 = _mm256_add_ps(s4, p4);
let c = _mm_load_ps(row);
let t = _mm256_set_m128(c, c);
let u = _mm256_set1_ps(*row.add(DnaSymbol::N.as_index()));
// check which bases from the sequence are unknown
let mask = _mm256_cmpeq_epi8(x, n);
let unk1 = _mm256_castsi256_ps(_mm256_shuffle_epi8(mask, m1));
let unk2 = _mm256_castsi256_ps(_mm256_shuffle_epi8(mask, m2));
let unk3 = _mm256_castsi256_ps(_mm256_shuffle_epi8(mask, m3));
let unk4 = _mm256_castsi256_ps(_mm256_shuffle_epi8(mask, m4));
// index A/T/G/C lookup table with the bases
let p1 = _mm256_permutevar_ps(t, x1);
let p2 = _mm256_permutevar_ps(t, x2);
let p3 = _mm256_permutevar_ps(t, x3);
let p4 = _mm256_permutevar_ps(t, x4);
// blend together known and unknown scores
let b1 = _mm256_blendv_ps(p1, u, unk1);
let b2 = _mm256_blendv_ps(p2, u, unk2);
let b3 = _mm256_blendv_ps(p3, u, unk3);
let b4 = _mm256_blendv_ps(p4, u, unk4);
// add log odds to the running sum
s1 = _mm256_add_ps(s1, b1);
s2 = _mm256_add_ps(s2, b2);
s3 = _mm256_add_ps(s3, b3);
s4 = _mm256_add_ps(s4, b4);
}
// record the score for the current position
let row = &mut result[i];
_mm256_store_ps(row[0..].as_mut_ptr(), s1);
_mm256_store_ps(row[8..].as_mut_ptr(), s2);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment