Skip to content
Snippets Groups Projects
Commit f7862877 authored by Martin Larralde's avatar Martin Larralde
Browse files

Add constructor for Python `ScoringMatrix` class

parent 093889bc
No related branches found
No related tags found
No related merge requests found
......@@ -436,6 +436,53 @@ pub struct ScoringMatrix {
#[pymethods]
impl ScoringMatrix {
/// Create a new scoring matrix.
#[new]
pub fn __init__(
_alphabet: &PyString,
values: &PyDict,
background: Option<PyObject>,
) -> PyResult<PyClassInitializer<Self>> {
// extract the background from the method argument
let bg = Python::with_gil(|py| {
if let Some(obj) = background {
if let Ok(d) = obj.extract::<&PyDict>(py) {
let p = dict_to_alphabet_array::<lightmotif::Dna>(d)?;
lightmotif::abc::Background::new(p)
.map_err(|_| PyValueError::new_err("Invalid background frequencies"))
} else {
Err(PyTypeError::new_err("Invalid type for pseudocount"))
}
} else {
Ok(lightmotif::abc::Background::uniform())
}
})?;
// build data
let mut data: Option<DenseMatrix<f32, <lightmotif::Dna as Alphabet>::K>> = None;
for s in lightmotif::Dna::symbols() {
let key = String::from(s.as_char());
if let Some(res) = values.get_item(&key) {
let column = res.downcast::<PyList>()?;
if data.is_none() {
data = Some(DenseMatrix::new(column.len()));
}
let matrix = data.as_mut().unwrap();
if matrix.rows() != column.len() {
return Err(PyValueError::new_err("Invalid number of rows"));
}
for (i, x) in column.iter().enumerate() {
matrix[i][s.as_index()] = x.extract::<f32>()?;
}
}
}
match data {
None => Err(PyValueError::new_err("Invalid count matrix")),
Some(matrix) => Ok(Self::from(lightmotif::ScoringMatrix::new(bg, matrix)).into()),
}
}
pub fn __eq__(&self, object: &PyAny) -> PyResult<bool> {
let py = object.py();
if let Ok(other) = object.extract::<Py<Self>>() {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment