From 8d763605a34e20aeb87114f8aeadc6024f4cd285 Mon Sep 17 00:00:00 2001 From: Martin Larralde <martin.larralde@embl.de> Date: Wed, 30 Aug 2023 22:30:33 +0200 Subject: [PATCH] Make `StripedSequence::new` return an error when given an invalid length --- lightmotif/src/pli/mod.rs | 4 ++-- lightmotif/src/seq.rs | 17 +++++++++++------ 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/lightmotif/src/pli/mod.rs b/lightmotif/src/pli/mod.rs index f124ce6..d7b0872 100644 --- a/lightmotif/src/pli/mod.rs +++ b/lightmotif/src/pli/mod.rs @@ -142,7 +142,7 @@ pub trait Stripe<A: Alphabet, C: StrictlyPositive> { let s = seq.as_ref(); let length = s.len(); let rows = (length / C::USIZE) + ((length % C::USIZE > 0) as usize); - let mut striped = StripedSequence::new(DenseMatrix::new(rows), length); + let mut striped = StripedSequence::new(DenseMatrix::new(rows), length).unwrap(); self.stripe_into(s, &mut striped); striped } @@ -159,7 +159,7 @@ pub trait Stripe<A: Alphabet, C: StrictlyPositive> { for (i, &x) in s.iter().enumerate() { data[i % rows][i / rows] = x; } - for i in s.len()..matrix.rows() * matrix.columns() { + for i in s.len()..data.rows() * data.columns() { data[i % rows][i / rows] = A::default_symbol(); } } diff --git a/lightmotif/src/seq.rs b/lightmotif/src/seq.rs index 3e48a88..12abfc2 100644 --- a/lightmotif/src/seq.rs +++ b/lightmotif/src/seq.rs @@ -11,6 +11,7 @@ use typenum::marker_traits::Unsigned; use super::abc::Alphabet; use super::abc::Symbol; use super::dense::DenseMatrix; +use super::err::InvalidData; use super::err::InvalidSymbol; use super::num::StrictlyPositive; use super::pwm::ScoringMatrix; @@ -145,12 +146,16 @@ pub struct StripedSequence<A: Alphabet, C: StrictlyPositive> { impl<A: Alphabet, C: StrictlyPositive> StripedSequence<A, C> { /// Create a new striped sequence from the given dense matrix. - pub fn new(data: DenseMatrix<A::Symbol, C>, length: usize) -> Self { - Self { - data, - length, - wrap: 0, - alphabet: std::marker::PhantomData, + pub fn new(data: DenseMatrix<A::Symbol, C>, length: usize) -> Result<Self, InvalidData> { + if data.rows() * data.columns() < length { + Err(InvalidData) + } else { + Ok(Self { + data, + length, + wrap: 0, + alphabet: std::marker::PhantomData, + }) } } -- GitLab