From cedd54be14bcd6fad260917a1b24501587006c1d Mon Sep 17 00:00:00 2001
From: Peter van 't Hof <p.j.van_t_hof@lumc.nl>
Date: Mon, 24 Aug 2015 17:26:39 +0200
Subject: [PATCH] Refactor RegionAfCount, 10-20 times faster now

---
 .../sasc/biopet/tools/RegionAfCount.scala     | 120 +++++++++++-------
 .../biopet/utils/intervals/BedRecord.scala    |  26 ++--
 .../utils/intervals/BedRecordList.scala       |  22 ++--
 3 files changed, 101 insertions(+), 67 deletions(-)

diff --git a/public/biopet-framework/src/main/scala/nl/lumc/sasc/biopet/tools/RegionAfCount.scala b/public/biopet-framework/src/main/scala/nl/lumc/sasc/biopet/tools/RegionAfCount.scala
index 81e9283e0..685a9b5ec 100644
--- a/public/biopet-framework/src/main/scala/nl/lumc/sasc/biopet/tools/RegionAfCount.scala
+++ b/public/biopet-framework/src/main/scala/nl/lumc/sasc/biopet/tools/RegionAfCount.scala
@@ -38,19 +38,19 @@ import scala.math._
 
 object RegionAfCount extends ToolCommand {
   case class Args(bedFile: File = null,
-                  outputFile: File = null,
-                  scatterpPlot: Option[File] = None,
+                  outputPrefix: String = null,
+                  scatterpPlot: Boolean = false,
                   vcfFiles: List[File] = Nil) extends AbstractArgs
 
   class OptParser extends AbstractOptParser {
-    opt[File]('b', "bedFile") required () maxOccurs 1 valueName "<file>" action { (x, c) =>
+    opt[File]('b', "bedFile") unbounded () required () maxOccurs 1 valueName "<file>" action { (x, c) =>
       c.copy(bedFile = x)
     }
-    opt[File]('o', "outputFile") required () maxOccurs 1 valueName "<file>" action { (x, c) =>
-      c.copy(outputFile = x)
+    opt[String]('o', "outputPrefix") unbounded () required () maxOccurs 1 valueName "<file prefix>" action { (x, c) =>
+      c.copy(outputPrefix = x)
     }
-    opt[File]('s', "scatterPlot") maxOccurs 1 valueName "<file>" action { (x, c) =>
-      c.copy(scatterpPlot = Some(x))
+    opt[Unit]('s', "scatterPlot") unbounded () action { (x, c) =>
+      c.copy(scatterpPlot = true)
     }
     opt[File]('V', "vcfFile") unbounded () minOccurs 1 action { (x, c) =>
       c.copy(vcfFiles = c.vcfFiles ::: x :: Nil)
@@ -71,68 +71,94 @@ object RegionAfCount extends ToolCommand {
     val combinedBedRecords = bedRecords.combineOverlap
 
     logger.info(s"${combinedBedRecords.allRecords.size} left")
-
+    logger.info(s"${combinedBedRecords.allRecords.size * cmdArgs.vcfFiles.size} query's to do")
     logger.info("Reading vcf files")
 
+    case class AfCounts(var names: Double = 0, // AF sum of every variant queried for the named record
+                        var namesExons: Double = 0, // AF sum of variants overlapping at least one exon of the record
+                        var namesIntrons: Double = 0, // AF sum of variants overlapping at least one intron of the record
+                        var namesCoding: Double = 0, // AF sum of exon-overlapping variants that touch neither UTR
+                        var utr: Double = 0, // AF sum of variants overlapping the 5' or the 3' UTR
+                        var utr5: Double = 0, // AF sum of variants overlapping the 5' UTR
+                        var utr3: Double = 0, // AF sum of variants overlapping the 3' UTR
+                        var exons: Map[String, Double] = Map(), // NOTE(review): never written in the visible hunks
+                        var intron: Map[String, Double] = Map()) // NOTE(review): never written in the visible hunks
+
     var c = 0
-    val afCountsRaw = for (region <- combinedBedRecords.allRecords.par) yield {
-      val sum = (for (vcfFile <- cmdArgs.vcfFiles.par) yield vcfFile -> {
-        val afCounts = mutable.Map[String, Double]()
-        val reader = new VCFFileReader(vcfFile, true)
-        val it = reader.query(region.chr, region.start, region.end)
-        for (variant <- it) {
+
+    val afCounts = (for (vcfFile <- cmdArgs.vcfFiles.par) yield vcfFile -> { // per VCF file: record name -> AF sums
+      val reader = new VCFFileReader(vcfFile, true)
+      val countsByName: mutable.Map[String, AfCounts] = mutable.Map()
+      try for (region <- combinedBedRecords.allRecords) { // the finally below closes the reader even if a query throws
+        val originals = region.originals()
+        for (variant <- reader.query(region.chr, region.start, region.end)) {
           val sum = (variant.getAttribute("AF", 0) match {
             case a: util.ArrayList[_] => a.map(_.toString.toDouble).toArray
             case s                    => Array(s.toString.toDouble)
           }).sum
-          region.originals()
-            .map(x => x.name.getOrElse(s"${x.chr}:${x.start}-${x.end}"))
-            .distinct
-            .foreach(name => afCounts += name -> (afCounts.getOrElse(name, 0.0) + sum))
+          val interval = BedRecord(variant.getContig, variant.getStart, variant.getEnd)
+          originals.foreach { x =>
+            val name = x.name.getOrElse(s"${x.chr}:${x.start}-${x.end}")
+            val counts = countsByName.getOrElseUpdate(name, AfCounts()) // single lookup instead of one per counter
+            counts.names += sum
+            val inExon = x.exons.getOrElse(Seq()).exists(_.overlapWith(interval))
+            val inIntron = x.introns.getOrElse(Seq()).exists(_.overlapWith(interval))
+            val inUtr5 = x.utr5.exists(_.overlapWith(interval))
+            val inUtr3 = x.utr3.exists(_.overlapWith(interval))
+            if (inExon) {
+              counts.namesExons += sum
+              if (!inUtr5 && !inUtr3) counts.namesCoding += sum // exonic and outside both UTRs
+            }
+            if (inIntron) counts.namesIntrons += sum
+            if (inUtr5 || inUtr3) counts.utr += sum
+            if (inUtr5) counts.utr5 += sum
+            if (inUtr3) counts.utr3 += sum
+          }
         }
-        reader.close()
-        afCounts.toMap
-      }).toMap
-
-      c += 1
-      if (c % 100 == 0) logger.info(s"$c regions done")
+        val done = synchronized { c += 1; c } // c is shared by the parallel per-file workers, so update it under a lock
+        if (done % 100 == 0) logger.info(s"$done regions done")
+      } finally reader.close()
+      countsByName.toMap
+    }).toMap
 
-      sum
-    }
+    logger.info(s"Done reading, ${c} regions")
 
-    logger.info(s"Done reading, $c regions")
+    logger.info("Writing output files")
 
-    val afCounts: Map[String, Map[File, Double]] = {
-      val combinedAfCounts: mutable.Map[String, mutable.Map[File, Double]] = mutable.Map()
-      for (x <- afCountsRaw.toList; (file, counts) <- x.toList; (name, count) <- counts) {
-        val map = combinedAfCounts.getOrElse(name, mutable.Map())
-        map += file -> (map.getOrElse(file, 0.0) + count)
-        combinedAfCounts += name -> map
+    def writeOutput(tsvFile: File, function: AfCounts => Double): Unit = { // one TSV: a row per record name, a column per VCF file
+      val writer = new PrintWriter(tsvFile)
+      try { // release the file handle even when building a row throws
+        writer.println("\t" + cmdArgs.vcfFiles.map(_.getName).mkString("\t"))
+        for (r <- cmdArgs.vcfFiles.foldLeft(Set[String]())((a, b) => a ++ afCounts(b).keySet)) {
+          writer.println(r + "\t" + cmdArgs.vcfFiles.map(x => function(afCounts(x).getOrElse(r, AfCounts()))).mkString("\t"))
-      }
+        }
-      combinedAfCounts.map(x => x._1 -> x._2.toMap).toMap
-    }
-
-    logger.info("Writing output file")
+      } finally writer.close()
+
-    val writer = new PrintWriter(cmdArgs.outputFile)
-    writer.println("\t" + cmdArgs.vcfFiles.map(_.getName).mkString("\t"))
-    for (r <- afCounts.keys) {
-      writer.print(r + "\t")
-      writer.println(cmdArgs.vcfFiles.map(afCounts(r).getOrElse(_, 0.0)).mkString("\t"))
+      if (cmdArgs.scatterpPlot) generatePlot(tsvFile)
     }
-    writer.close()
 
-    cmdArgs.scatterpPlot.foreach { scatterPlotFile =>
-      logger.info("Generate plot")
+    def generatePlot(tsvFile: File): Unit = { // runs ScatterPlot locally on one of the written TSV tables
+      logger.info(s"Generate plot for $tsvFile")
 
       val scatterPlot = new ScatterPlot(null)
-      scatterPlot.input = cmdArgs.outputFile
-      scatterPlot.output = scatterPlotFile
+      scatterPlot.input = tsvFile
+      scatterPlot.output = new File(tsvFile.getAbsolutePath.stripSuffix(".tsv") + ".png") // same path as the table, ".tsv" replaced by ".png"
       scatterPlot.ylabel = Some("Sum of AFs")
       scatterPlot.width = Some(1200)
       scatterPlot.height = Some(1000)
       scatterPlot.runLocal()
     }
+    // File-name suffix and counter selector of every output table; the tables are written concurrently.
+    val outputTables = List[(String, AfCounts => Double)](
+      (".names.tsv", _.names),
+      (".names.exons_only.tsv", _.namesExons),
+      (".names.introns_only.tsv", _.namesIntrons),
+      (".names.coding.tsv", _.namesCoding),
+      (".names.utr.tsv", _.utr),
+      (".names.utr5.tsv", _.utr5),
+      (".names.utr3.tsv", _.utr3)
+    )
+    outputTables.par.foreach { case (suffix, pick) => writeOutput(new File(cmdArgs.outputPrefix + suffix), pick) }
 
     logger.info("Done")
   }
diff --git a/public/biopet-framework/src/main/scala/nl/lumc/sasc/biopet/utils/intervals/BedRecord.scala b/public/biopet-framework/src/main/scala/nl/lumc/sasc/biopet/utils/intervals/BedRecord.scala
index a7222a93b..b4dd0135c 100644
--- a/public/biopet-framework/src/main/scala/nl/lumc/sasc/biopet/utils/intervals/BedRecord.scala
+++ b/public/biopet-framework/src/main/scala/nl/lumc/sasc/biopet/utils/intervals/BedRecord.scala
@@ -23,27 +23,33 @@ case class BedRecord(chr: String,
     else _originals
   }
 
+  /** True when `record` lies on the same `chr` and shares at least one position with this record (both ends inclusive). */
+  def overlapWith(record: BedRecord): Boolean = {
+    val sameChr = chr == record.chr
+    sameChr && start <= record.end && record.start <= end
+  }
+
   def length = end - start + 1
 
   lazy val exons = if (blockCount.isDefined && blockSizes.length > 0 && blockStarts.length > 0) {
-    Some(for (i <- 0 to blockCount.get) yield {
+    Some(for (i <- 0 until blockCount.get) yield { // one record per block index 0 .. blockCount - 1
       val exonNumber = strand match {
         case Some(false) => blockCount.get - i
         case _           => i + 1
       }
       BedRecord(chr, start + blockStarts(i), start + blockStarts(i) + blockSizes(i),
-        name.map(_ + s"_exon-$exonNumber"), _originals = List(this))
+        Some(s"exon-$exonNumber"), _originals = List(this)) // numbered blockCount..1 for strand Some(false), 1..blockCount otherwise
     })
   } else None
 
   lazy val introns = if (blockCount.isDefined && blockSizes.length > 0 && blockStarts.length > 0) {
-    Some(for (i <- 0 to (blockCount.get - 1)) yield {
+    Some(for (i <- 0 until (blockCount.get - 1)) yield { // one intron per gap between consecutive blocks
       val intronNumber = strand match {
-        case Some(false) => blockCount.get - i
+        case Some(false) => blockCount.get - 1 - i // only blockCount - 1 introns exist, so count down from there to 1
         case _           => i + 1
       }
-      BedRecord(chr, start + start + blockStarts(i) + blockSizes(i) + 1, start + blockStarts(i + 1) - 1,
-        name.map(_ + s"_intron-$intronNumber"), _originals = List(this))
+      BedRecord(chr, start + blockStarts(i) + blockSizes(i) + 1, start + blockStarts(i + 1) - 1, // from just past block i to just before block i + 1
+        Some(s"intron-$intronNumber"), _originals = List(this))
     })
   } else None
 
@@ -56,8 +62,10 @@ case class BedRecord(chr: String,
   }
 
   lazy val utr3 = (strand, thickStart, thickEnd) match {
-    case (Some(false), Some(tStart), Some(tEnd)) => Some(BedRecord(chr, start, tStart - 1, name.map(_ + "_utr3")))
-    case (Some(true), Some(tStart), Some(tEnd)) => Some(BedRecord(chr, tEnd + 1, end, name.map(_ + "_utr3")))
+    case (Some(false), Some(tStart), Some(tEnd)) if (tStart > start && tEnd < end) => // only when the thick region lies strictly inside the record
+      Some(BedRecord(chr, start, tStart - 1, name.map(_ + "_utr3"))) // strand Some(false): the part before thickStart
+    case (Some(true), Some(tStart), Some(tEnd)) if (tStart > start && tEnd < end) => // same guard as above
+      Some(BedRecord(chr, tEnd + 1, end, name.map(_ + "_utr3"))) // strand Some(true): the part after thickEnd
     case _ => None
   }
 
@@ -79,7 +87,7 @@ case class BedRecord(chr: String,
     require(start <= end, "Start is greater then end")
     (thickStart, thickEnd) match {
       case (Some(s), Some(e)) => require(s <= e, "Thick start is greater then end")
-      case _ =>
+      case _                  =>
     }
     blockCount match {
       case Some(count) => {
@@ -105,7 +113,7 @@ object BedRecord {
       values.lift(5).map {
         case "-" => false
         case "+" => true
-        case _ => throw new IllegalStateException("Strand (column 6) must be '+' or '-'")
+        case _   => throw new IllegalStateException("Strand (column 6) must be '+' or '-'")
       },
       values.lift(6).map(_.toInt),
       values.lift(7) map (_.toInt),
diff --git a/public/biopet-framework/src/main/scala/nl/lumc/sasc/biopet/utils/intervals/BedRecordList.scala b/public/biopet-framework/src/main/scala/nl/lumc/sasc/biopet/utils/intervals/BedRecordList.scala
index 351ab6b08..a0d9153de 100644
--- a/public/biopet-framework/src/main/scala/nl/lumc/sasc/biopet/utils/intervals/BedRecordList.scala
+++ b/public/biopet-framework/src/main/scala/nl/lumc/sasc/biopet/utils/intervals/BedRecordList.scala
@@ -25,7 +25,7 @@ class BedRecordList(val chrRecords: Map[String, List[BedRecord]], header: List[S
     .dropWhile(_.end < record.start)
     .takeWhile(_.start <= record.end)
 
-  def length = allRecords.foldLeft(0L)((a,b) => a + b.length)
+  def length = allRecords.foldLeft(0L)((a, b) => a + b.length)
 
   def squishBed(strandSensitive: Boolean = true) = BedRecordList.fromList {
     (for ((chr, records) <- sort.chrRecords; record <- records) yield {
@@ -37,15 +37,15 @@ class BedRecordList(val chrRecords: Map[String, List[BedRecord]], header: List[S
       } else {
         overlaps
           .foldLeft(List(record))((result, overlap) => {
-          (for (r <- result) yield {
-            (overlap.start < r.start, overlap.end > r.end) match {
-              case (true, true) => Nil
-              case (true, false) => List(r.copy(start = overlap.end + 1))
-              case (false, true) => List(r.copy(end = overlap.start - 1))
-              case (false, false) => List(r.copy(end = overlap.start - 1), r.copy(start = overlap.end + 1))
-            }
-          }).flatten
-        })
+            (for (r <- result) yield {
+              (overlap.start < r.start, overlap.end > r.end) match {
+                case (true, true)   => Nil
+                case (true, false)  => List(r.copy(start = overlap.end + 1))
+                case (false, true)  => List(r.copy(end = overlap.start - 1))
+                case (false, false) => List(r.copy(end = overlap.start - 1), r.copy(start = overlap.end + 1))
+              }
+            }).flatten
+          })
       }
     }).flatten
   }
@@ -77,7 +77,7 @@ class BedRecordList(val chrRecords: Map[String, List[BedRecord]], header: List[S
 
 object BedRecordList {
   def fromListWithHeader(records: Traversable[BedRecord],
-               header: List[String]): BedRecordList = fromListWithHeader(records.toIterator, header)
+                         header: List[String]): BedRecordList = fromListWithHeader(records.toIterator, header)
 
   def fromListWithHeader(records: TraversableOnce[BedRecord], header: List[String]): BedRecordList = {
     val map = mutable.Map[String, ListBuffer[BedRecord]]()
-- 
GitLab