From cd4605726301d7bf5926db8a779a3f2595de298b Mon Sep 17 00:00:00 2001
From: Martin Larralde <martin.larralde@embl.de>
Date: Fri, 30 Aug 2024 21:34:04 +0200
Subject: [PATCH] Add `Motif.name` property to store the name of a motif in
 `lightmotif-py`

---
 lightmotif-py/lightmotif/io.rs  | 16 +++++++++++++---
 lightmotif-py/lightmotif/lib.rs |  9 +++++++--
 2 files changed, 20 insertions(+), 5 deletions(-)

diff --git a/lightmotif-py/lightmotif/io.rs b/lightmotif-py/lightmotif/io.rs
index 8976169..99b6c6d 100644
--- a/lightmotif-py/lightmotif/io.rs
+++ b/lightmotif-py/lightmotif/io.rs
@@ -25,8 +25,12 @@ where
     ScoringMatrixData: From<lightmotif::pwm::ScoringMatrix<A>>,
 {
     let record = record.unwrap();
+    let name = record.id().to_string();
     let counts = record.into_matrix();
-    Python::with_gil(|py| Motif::from_counts(py, counts))
+
+    let mut motif = Python::with_gil(|py| Motif::from_counts(py, counts))?;
+    motif.name = Some(name);
+    Ok(motif)
 }
 
 fn convert_transfac<A>(record: Result<lightmotif_io::transfac::Record<A>, Error>) -> PyResult<Motif>
@@ -37,10 +41,13 @@ where
     ScoringMatrixData: From<lightmotif::pwm::ScoringMatrix<A>>,
 {
     let record = record.unwrap();
+    let name = record.accession().or(record.id()).map(String::from);
     let counts = record
         .to_counts()
         .ok_or_else(|| PyValueError::new_err("invalid count matrix"))?;
-    Python::with_gil(|py| Motif::from_counts(py, counts))
+    let mut motif = Python::with_gil(|py| Motif::from_counts(py, counts))?;
+    motif.name = name;
+    Ok(motif)
 }
 
 fn convert_uniprobe<A>(record: Result<lightmotif_io::uniprobe::Record<A>, Error>) -> PyResult<Motif>
@@ -50,9 +57,12 @@ where
     ScoringMatrixData: From<lightmotif::pwm::ScoringMatrix<A>>,
 {
     let record = record.unwrap();
+    let name = record.id().to_string();
     let freqs = record.into_matrix();
     let weights = freqs.to_weight(None);
-    Python::with_gil(|py| Motif::from_weights(py, weights))
+    let mut motif = Python::with_gil(|py| Motif::from_weights(py, weights))?;
+    motif.name = Some(name);
+    Ok(motif)
 }
 
 #[pyclass(module = "lightmotif.lib")]
diff --git a/lightmotif-py/lightmotif/lib.rs b/lightmotif-py/lightmotif/lib.rs
index 79bc65f..8d6c723 100644
--- a/lightmotif-py/lightmotif/lib.rs
+++ b/lightmotif-py/lightmotif/lib.rs
@@ -1049,6 +1049,8 @@ pub struct Motif {
     pwm: Py<WeightMatrix>,
     #[pyo3(get)]
     pssm: Py<ScoringMatrix>,
+    #[pyo3(get)]
+    name: Option<String>,
 }
 
 impl Motif {
@@ -1065,6 +1067,7 @@ impl Motif {
             counts: Some(Py::new(py, CountMatrix::new(counts))?),
             pwm: Py::new(py, WeightMatrix::new(weights))?,
             pssm: Py::new(py, ScoringMatrix::new(scoring))?,
+            name: None,
         })
     }
 
@@ -1079,6 +1082,7 @@ impl Motif {
             counts: None,
             pwm: Py::new(py, WeightMatrix::new(weights))?,
             pssm: Py::new(py, ScoringMatrix::new(scoring))?,
+            name: None,
         })
     }
 }
@@ -1154,8 +1158,8 @@ impl From<lightmotif::scan::Hit> for Hit {
 ///         or when the sequence lengths are not consistent.
 ///
 #[pyfunction]
-#[pyo3(signature = (sequences, protein = false))]
-pub fn create(sequences: Bound<PyAny>, protein: bool) -> PyResult<Motif> {
+#[pyo3(signature = (sequences, protein = false, name = None))]
+pub fn create(sequences: Bound<PyAny>, protein: bool, name: Option<String>) -> PyResult<Motif> {
     let py = sequences.py();
     macro_rules! run {
         ($alphabet:ty) => {{
@@ -1178,6 +1182,7 @@ pub fn create(sequences: Bound<PyAny>, protein: bool) -> PyResult<Motif> {
                 counts: Some(Py::new(py, CountMatrix::new(data))?),
                 pwm: Py::new(py, WeightMatrix::new(weights))?,
                 pssm: Py::new(py, ScoringMatrix::new(scoring))?,
+                name,
             })
         }};
     }
-- 
GitLab