Skip to content
Snippets Groups Projects
Commit 9ede4ee1 authored by Martin Larralde
Browse files

Remove `alphabet` parameter from `ScoringMatrix` Python class

parent d264b724
No related branches found
No related tags found
No related merge requests found
......@@ -15,7 +15,7 @@ AVX2_SUPPORTED: bool
FORMAT = Literal["jaspar", "jaspar16", "uniprobe", "transfac"]
class EncodedSequence:
def __init__(self, sequence: str, protein: bool = False) -> None: ...
def __init__(self, sequence: str, *, protein: bool = False) -> None: ...
def __str__(self) -> str: ...
def __len__(self) -> int: ...
def __copy__(self) -> EncodedSequence: ...
......@@ -34,7 +34,7 @@ class StripedSequence:
class CountMatrix:
def __init__(
self, values: Dict[str, Iterable[int]], protein: bool = False
self, values: Dict[str, Iterable[int]], *, protein: bool = False
) -> None: ...
def __eq__(self, other: object) -> bool: ...
def __len__(self) -> int: ...
......@@ -57,7 +57,11 @@ class WeightMatrix:
class ScoringMatrix:
def __init__(
self, values: Dict[str, Iterable[float]], protein: bool = False
self,
values: Dict[str, Iterable[float]],
background: Optional[Dict[str, float]] = None,
*,
protein: bool = False,
) -> None: ...
def __len__(self) -> int: ...
def __eq__(self, other: object) -> bool: ...
......@@ -125,16 +129,24 @@ class Hit:
M = typing.TypeVar("M", bound=Motif)
class Loader(Generic[M], Iterator[M]):
def __init__(
self,
file: Union[BinaryIO, PathLike[str]],
format: Literal["jaspar16"],
*,
protein: bool = False,
) -> None: ...
def __iter__(self) -> Loader[M]: ...
def __next__(self) -> M: ...
def create(
sequences: Iterable[str], protein: bool = False, name: Optional[str] = None
sequences: Iterable[str], *, protein: bool = False, name: Optional[str] = None
) -> Motif: ...
def stripe(sequence: str, protein: bool = False) -> StripedSequence: ...
def stripe(sequence: str, *, protein: bool = False) -> StripedSequence: ...
def scan(
pssm: ScoringMatrix,
sequence: StripedSequence,
*,
threshold: float = 0.0,
block_size: int = 256,
) -> Scanner: ...
......@@ -143,33 +155,33 @@ def load(
file: Union[BinaryIO, PathLike[str]],
format: Literal["jaspar"],
*,
protein: bool = False
protein: bool = False,
) -> Loader[JasparMotif]: ...
@typing.overload
def load(
file: Union[BinaryIO, PathLike[str]],
format: Literal["jaspar16"],
*,
protein: bool = False
protein: bool = False,
) -> Loader[JasparMotif]: ...
@typing.overload
def load(
file: Union[BinaryIO, PathLike[str]],
format: Literal["uniprobe"],
*,
protein: bool = False
protein: bool = False,
) -> Loader[UniprobeMotif]: ...
@typing.overload
def load(
file: Union[BinaryIO, PathLike[str]],
format: Literal["transfac"],
*,
protein: bool = False
protein: bool = False,
) -> Loader[TransfacMotif]: ...
@typing.overload
def load(
file: Union[BinaryIO, PathLike[str]],
format: FORMAT = "jaspar",
*,
protein: bool = False
protein: bool = False,
) -> Loader[Motif]: ...
......@@ -434,7 +434,7 @@ impl CountMatrix {
/// Create a new count matrix.
#[new]
#[allow(unused_variables)]
#[pyo3(signature = (values, protein = false))]
#[pyo3(signature = (values, *, protein = false))]
pub fn __init__<'py>(
values: Bound<'py, PyDict>,
protein: bool,
......@@ -746,50 +746,63 @@ impl ScoringMatrix {
impl ScoringMatrix {
/// Create a new scoring matrix.
#[new]
#[pyo3(signature = (alphabet, values, background = None))]
#[pyo3(signature = (values, background = None, *, protein = false))]
#[allow(unused)]
pub fn __init__<'py>(
alphabet: Bound<'py, PyString>,
values: Bound<'py, PyDict>,
background: Option<PyObject>,
protein: bool,
) -> PyResult<PyClassInitializer<Self>> {
// extract the background from the method argument
let bg = Python::with_gil(|py| {
if let Some(obj) = background {
if let Ok(d) = obj.extract::<Bound<PyDict>>(py) {
let p = dict_to_alphabet_array::<Dna>(d)?;
lightmotif::abc::Background::new(p)
.map_err(|_| PyValueError::new_err("Invalid background frequencies"))
} else {
Err(PyTypeError::new_err("Invalid type for pseudocount"))
}
} else {
Ok(lightmotif::abc::Background::uniform())
}
})?;
// build data
let mut data: Option<DenseMatrix<f32, <Dna as Alphabet>::K>> = None;
for s in Dna::symbols() {
let key = String::from(s.as_char());
if let Some(res) = values.get_item(&key)? {
let column = res.downcast::<PyList>()?;
if data.is_none() {
data = Some(DenseMatrix::new(column.len()));
}
let matrix = data.as_mut().unwrap();
if matrix.rows() != column.len() {
return Err(PyValueError::new_err("Invalid number of rows"));
macro_rules! run {
($alphabet:ty) => {{
// extract the background from the method argument
let bg = Python::with_gil(|py| {
if let Some(obj) = background {
if let Ok(d) = obj.extract::<Bound<PyDict>>(py) {
let p = dict_to_alphabet_array::<$alphabet>(d)?;
lightmotif::abc::Background::<$alphabet>::new(p).map_err(|_| {
PyValueError::new_err("Invalid background frequencies")
})
} else {
Err(PyTypeError::new_err("Invalid type for pseudocount"))
}
} else {
Ok(lightmotif::abc::Background::uniform())
}
})?;
// build data
let mut data: Option<DenseMatrix<f32, <$alphabet as Alphabet>::K>> = None;
for s in <$alphabet as Alphabet>::symbols() {
let key = String::from(s.as_char());
if let Some(res) = values.get_item(&key)? {
let column = res.downcast::<PyList>()?;
if data.is_none() {
data = Some(DenseMatrix::new(column.len()));
}
let matrix = data.as_mut().unwrap();
if matrix.rows() != column.len() {
return Err(PyValueError::new_err("Invalid number of rows"));
}
for (i, x) in column.iter().enumerate() {
matrix[i][s.as_index()] = x.extract::<f32>()?;
}
}
}
for (i, x) in column.iter().enumerate() {
matrix[i][s.as_index()] = x.extract::<f32>()?;
// create matrix
match data {
None => Err(PyValueError::new_err("Invalid count matrix")),
Some(matrix) => Ok(Self::new(lightmotif::ScoringMatrix::<$alphabet>::new(
bg, matrix,
))
.into()),
}
}
}};
}
match data {
None => Err(PyValueError::new_err("Invalid count matrix")),
Some(matrix) => Ok(Self::new(lightmotif::ScoringMatrix::<Dna>::new(bg, matrix)).into()),
if protein {
run!(Protein)
} else {
run!(Dna)
}
}
......@@ -1272,7 +1285,7 @@ impl From<lightmotif::scan::Hit> for Hit {
/// or when the sequence lengths are not consistent.
///
#[pyfunction]
#[pyo3(signature = (sequences, protein = false, name = None))]
#[pyo3(signature = (sequences, *, protein = false, name = None))]
pub fn create(sequences: Bound<PyAny>, protein: bool, name: Option<String>) -> PyResult<Motif> {
let py = sequences.py();
macro_rules! run {
......@@ -1323,7 +1336,7 @@ pub fn create(sequences: Bound<PyAny>, protein: bool, name: Option<String>) -> P
/// `ValueError`: When the sequences contains an invalid character.
///
#[pyfunction]
#[pyo3(signature = (sequence, protein=false))]
#[pyo3(signature = (sequence, *, protein=false))]
pub fn stripe(sequence: Bound<PyString>, protein: bool) -> PyResult<StripedSequence> {
let py = sequence.py();
let encoded = EncodedSequence::__init__(sequence, protein).and_then(|e| Py::new(py, e))?;
......@@ -1359,7 +1372,7 @@ pub fn stripe(sequence: Bound<PyString>, protein: bool) -> PyResult<StripedSeque
/// implementation requirements.
///
#[pyfunction]
#[pyo3(signature = (pssm, sequence, threshold = 0.0, block_size = 256))]
#[pyo3(signature = (pssm, sequence, *, threshold = 0.0, block_size = 256))]
pub fn scan<'py>(
pssm: Bound<'py, ScoringMatrix>,
sequence: Bound<'py, StripedSequence>,
......
0% Loading
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment