From bbde5508117eb0ee39026993e4c372ed835ba470 Mon Sep 17 00:00:00 2001 From: Martin Larralde <martin.larralde@embl.de> Date: Thu, 20 Jun 2024 17:14:35 +0200 Subject: [PATCH] Reorganize tests in `lightmotif` crate --- lightmotif/tests/dna.rs | 206 +++++++++++++++++++++++-------------- lightmotif/tests/encode.rs | 95 +++++++++++++---- 2 files changed, 200 insertions(+), 101 deletions(-) diff --git a/lightmotif/tests/dna.rs b/lightmotif/tests/dna.rs index a6aaabf..989dec1 100644 --- a/lightmotif/tests/dna.rs +++ b/lightmotif/tests/dna.rs @@ -89,6 +89,35 @@ fn test_score<C: StrictlyPositive, P: Score<f32, Dna, C>>(pli: &P) { } } +fn test_score_discrete<C: StrictlyPositive, P: Score<u8, Dna, C>>(pli: &P) { + let encoded = EncodedSequence::<Dna>::encode(SEQUENCE).unwrap(); + let mut striped = Pipeline::generic().stripe(encoded); + + let cm = CountMatrix::<Dna>::from_sequences( + PATTERNS.iter().map(|x| EncodedSequence::encode(x).unwrap()), + ) + .unwrap(); + let pbm = cm.to_freq(0.1); + let pwm = pbm.to_weight(None); + let pssm = pwm.to_scoring(); + let dm = pssm.to_discrete(); + + striped.configure(&pssm); + let result = pli.score(&dm, &striped); + let scores = result.unstripe(); + + assert_eq!(scores.len(), EXPECTED.len()); + for i in 0..scores.len() { + assert!( + dm.unscale(scores[i]) >= EXPECTED[i], + "{} != {} at position {}", + dm.unscale(scores[i]), + EXPECTED[i], + i + ); + } +} + fn test_argmax<C: StrictlyPositive, P: Score<f32, Dna, C> + Maximum<f32, C>>(pli: &P) { let encoded = EncodedSequence::<Dna>::encode(SEQUENCE).unwrap(); let mut striped = Pipeline::generic().stripe(encoded); @@ -197,104 +226,121 @@ mod generic { } } -#[test] -fn test_score_dispatch() { - let pli = Pipeline::dispatch(); - test_score(&pli); -} +mod dispatch { + use super::*; -#[test] -fn test_argmax_dispatch() { - let pli = Pipeline::dispatch(); - test_argmax(&pli); -} + #[test] + fn test_score() { + let pli = Pipeline::dispatch(); + super::test_score::<U32, _>(&pli); + } -#[test] -fn test_threshold_dispatch() { - let pli = Pipeline::dispatch(); - test_threshold(&pli); -} + #[test] + fn test_score_rows() { + let pli = Pipeline::dispatch(); + super::test_score_rows::<U32, _>(&pli); + } -#[cfg(target_feature = "sse2")] -#[test] -fn test_score_sse2() { - let pli = Pipeline::sse2().unwrap(); - test_score::<U16, _>(&pli); -} + #[test] + fn test_argmax() { + let pli = Pipeline::dispatch(); + super::test_argmax::<U32, _>(&pli); + } -#[cfg(target_feature = "sse2")] -#[test] -fn test_argmax_sse2() { - let pli = Pipeline::sse2().unwrap(); - test_argmax::<U16, _>(&pli); + #[test] + fn test_threshold() { + let pli = Pipeline::dispatch(); + super::test_threshold::<U32, _>(&pli); + } } #[cfg(target_feature = "sse2")] -#[test] -fn test_threshold_sse2() { - let pli = Pipeline::sse2().unwrap(); - test_threshold::<U16, _>(&pli); -} +mod sse2 { + use super::*; -#[cfg(target_feature = "sse2")] -#[test] -fn test_score_sse2_32() { - let pli = Pipeline::sse2().unwrap(); - test_score::<U32, _>(&pli); -} + #[test] + fn test_score() { + let pli = Pipeline::sse2().unwrap(); + super::test_score::<U32, _>(&pli); + super::test_score::<U16, _>(&pli); + } -#[cfg(target_feature = "sse2")] -#[test] -fn test_argmax_sse2_32() { - let pli = Pipeline::sse2().unwrap(); - test_argmax::<U32, _>(&pli); -} + #[test] + fn test_score_rows() { + let pli = Pipeline::sse2().unwrap(); + super::test_score_rows::<U32, _>(&pli); + super::test_score_rows::<U16, _>(&pli); + } -#[cfg(target_feature = "sse2")] -#[test] -fn test_threshold_sse2_32() { - let pli = Pipeline::sse2().unwrap(); - test_threshold::<U32, _>(&pli); -} + #[test] + fn test_argmax() { + let pli = Pipeline::sse2().unwrap(); + super::test_argmax::<U32, _>(&pli); + super::test_argmax::<U16, _>(&pli); + } -#[cfg(target_feature = "avx2")] -#[test] -fn test_score_avx2() { - let pli = Pipeline::avx2().unwrap(); - test_score(&pli); + #[test] + fn test_threshold() { + let pli = Pipeline::sse2().unwrap(); + super::test_threshold::<U32, _>(&pli); + super::test_threshold::<U16, _>(&pli); + } } #[cfg(target_feature = "avx2")] -#[test] -fn test_argmax_avx2() { - let pli = Pipeline::avx2().unwrap(); - test_argmax(&pli); -} +mod avx2 { + use super::*; -#[cfg(target_feature = "avx2")] -#[test] -fn test_threshold_avx2() { - let pli = Pipeline::avx2().unwrap(); - test_threshold::<U32, _>(&pli); -} + #[test] + fn test_score() { + let pli = Pipeline::avx2().unwrap(); + super::test_score::<U32, _>(&pli); + } -#[cfg(target_feature = "neon")] -#[test] -fn test_score_neon() { - let pli = Pipeline::neon().unwrap(); - test_score::<U16, _>(&pli); -} + #[test] + fn test_score_rows() { + let pli = Pipeline::avx2().unwrap(); + super::test_score_rows::<U32, _>(&pli); + } -#[cfg(target_feature = "neon")] -#[test] -fn test_argmax_neon() { - let pli = Pipeline::neon().unwrap(); - test_argmax::<U16, _>(&pli); + #[test] + fn test_argmax() { + let pli = Pipeline::avx2().unwrap(); + super::test_argmax::<U32, _>(&pli); + } + + #[test] + fn test_threshold() { + let pli = Pipeline::avx2().unwrap(); + super::test_threshold::<U32, _>(&pli); + } } #[cfg(target_feature = "neon")] -#[test] -fn test_threshold_neon() { - let pli = Pipeline::neon().unwrap(); - test_threshold::<U16, _>(&pli); +mod neon { + use super::*; + + #[test] + fn test_score() { + let pli = Pipeline::neon().unwrap(); + super::test_score::<U16, _>(&pli); + } + + #[test] + fn test_score_rows() { + let pli = Pipeline::neon().unwrap(); + super::test_score_rows::<U16, _>(&pli); + } + + #[test] + fn test_argmax() { + let pli = Pipeline::neon().unwrap(); + super::test_argmax::<U16, _>(&pli); + } + + #[test] + fn test_threshold() { + let pli = Pipeline::neon().unwrap(); + super::test_threshold::<U16, _>(&pli); + } } diff --git a/lightmotif/tests/encode.rs b/lightmotif/tests/encode.rs index 0356bcc..8376df4 100644 --- a/lightmotif/tests/encode.rs +++ b/lightmotif/tests/encode.rs @@ -15,42 +15,95 @@ const EXPECTED: &[Nucleotide] = &[ A, T, G, T, C, C, C, A, A, C, A, A, C, G, A, T, A, C, C, C, C, G, A, G, C, C, C, A, T, C, G, C, C, G, T, C, A, T, C, G, G, C, T, C, G, G, C, A, T, G, C, A, G, A, T, T, C, C, C, A, G, G, C, G ]; -fn test_encode<P: Encode<Dna>>(pli: &P) { +fn test_encode_sequence<P: Encode<Dna>>(pli: &P) { let encoded = pli.encode(SEQUENCE).unwrap(); assert_eq!(encoded, EXPECTED); +} + +fn test_encode_unknown<P: Encode<Dna>>(pli: &P) { let err = pli.encode(UNKNOWNS).unwrap_err(); assert_eq!(err.0, '.'); } -#[test] -fn test_encode_generic() { - let pli = Pipeline::generic(); - test_encode(&pli); +mod generic { + use super::*; + + #[test] + fn test_sequence() { + let pli = Pipeline::generic(); + test_encode_sequence(&pli); + } + + #[test] + fn test_unknown() { + let pli = Pipeline::generic(); + test_encode_unknown(&pli); + } } -#[test] -fn test_encode_dispatch() { - let pli = Pipeline::dispatch(); - test_encode(&pli); +mod dispatch { + use super::*; + + #[test] + fn test_sequence() { + let pli = Pipeline::dispatch(); + test_encode_sequence(&pli); + } + + #[test] + fn test_unknown() { + let pli = Pipeline::dispatch(); + test_encode_unknown(&pli); + } } #[cfg(target_feature = "sse2")] -#[test] -fn test_encode_sse2() { - let pli = Pipeline::sse2().unwrap(); - test_encode(&pli); +mod sse2 { + use super::*; + + #[test] + fn test_sequence() { + let pli = Pipeline::sse2().unwrap(); + test_encode_sequence(&pli); + } + + #[test] + fn test_unknown() { + let pli = Pipeline::sse2().unwrap(); + test_encode_unknown(&pli); + } } #[cfg(target_feature = "avx2")] -#[test] -fn test_encode_avx2() { - let pli = Pipeline::avx2().unwrap(); - test_encode(&pli); +mod avx2 { + use super::*; + + #[test] + fn test_sequence() { + let pli = Pipeline::avx2().unwrap(); + test_encode_sequence(&pli); + } + + #[test] + fn test_unknown() { + let pli = Pipeline::avx2().unwrap(); + test_encode_unknown(&pli); + } } #[cfg(target_feature = "neon")] -#[test] -fn test_encode_neon() { - let pli = Pipeline::neon().unwrap(); - test_encode(&pli); +mod neon { + use super::*; + + #[test] + fn test_sequence() { + let pli = Pipeline::neon().unwrap(); + test_encode_sequence(&pli); + } + + #[test] + fn test_unknown() { + let pli = Pipeline::neon().unwrap(); + test_encode_unknown(&pli); + } } -- GitLab