From b178746f0136dc2a233f9dbe6f8b5add8bcd9a00 Mon Sep 17 00:00:00 2001
From: Peter van 't Hof <p.j.van_t_hof@lumc.nl>
Date: Tue, 3 Feb 2015 09:59:33 +0100
Subject: [PATCH] Prep for parallelization

---
 .../nl/lumc/sasc/biopet/tools/VcfStats.scala  | 138 +++++++++++++-----
 1 file changed, 98 insertions(+), 40 deletions(-)

diff --git a/public/biopet-framework/src/main/scala/nl/lumc/sasc/biopet/tools/VcfStats.scala b/public/biopet-framework/src/main/scala/nl/lumc/sasc/biopet/tools/VcfStats.scala
index b6022b7f6..1602c7e76 100644
--- a/public/biopet-framework/src/main/scala/nl/lumc/sasc/biopet/tools/VcfStats.scala
+++ b/public/biopet-framework/src/main/scala/nl/lumc/sasc/biopet/tools/VcfStats.scala
@@ -2,7 +2,7 @@ package nl.lumc.sasc.biopet.tools
 
 import java.io.{ FileOutputStream, PrintWriter, File }
 
-import htsjdk.variant.variantcontext.Genotype
+import htsjdk.variant.variantcontext.{ VariantContext, Genotype }
 import htsjdk.variant.vcf.VCFFileReader
 import nl.lumc.sasc.biopet.core.ToolCommand
 import org.broadinstitute.gatk.utils.R.RScriptExecutor
@@ -31,10 +31,65 @@ object VcfStats extends ToolCommand {
     }
   }
 
-  val genotypeOverlap: mutable.Map[String, mutable.Map[String, Int]] = mutable.Map()
-  val allelesOverlap: mutable.Map[String, mutable.Map[String, Int]] = mutable.Map()
-  val qualStats: mutable.Map[Any, Int] = mutable.Map()
-  val genotypeStats: mutable.Map[String, mutable.Map[String, mutable.Map[Any, Int]]] = mutable.Map()
+  class SampleToSampleStats { // Mutable pair of overlap counters kept for one (sample, sample) combination
+    var genotypeOverlap: Int = 0 // incremented when both samples report equal allele lists for a record
+    var alleleOverlap: Int = 0 // incremented per allele of the first sample that the second sample also has
+
+    def +=(other: SampleToSampleStats): Unit = { // adds both counters of `other` onto this object
+      this.genotypeOverlap += other.genotypeOverlap
+      this.alleleOverlap += other.alleleOverlap
+    }
+  }
+
+  class SampleStats { // Counts collected for one sample: field -> value -> count, plus pairwise overlap
+    val genotypeStats: mutable.Map[String, mutable.Map[Any, Int]] = mutable.Map()
+    val sampleToSample: mutable.Map[String, SampleToSampleStats] = mutable.Map()
+
+    def +=(other: SampleStats): Unit = {
+      for ((key, value) <- other.sampleToSample) {
+        // Add into an entry owned by this object; storing `value` itself would alias `other`
+        this.sampleToSample.getOrElseUpdate(key, new SampleToSampleStats) += value
+      }
+      for ((field, fieldMap) <- other.genotypeStats) {
+        // Same for the per-field count maps: the counts are copied into a map of our own,
+        // so later merges into this object can never change the maps held by `other`
+        mergeStatsMap(this.genotypeStats.getOrElseUpdate(field, mutable.Map()), fieldMap)
+      }
+    }
+  }
+
+  class Stats { // Top-level accumulator: record-wide field counts plus one SampleStats per sample name
+    val generalStats: mutable.Map[String, mutable.Map[Any, Int]] = mutable.Map()
+    val samplesStats: mutable.Map[String, SampleStats] = mutable.Map()
+
+    def +=(other: Stats): Unit = {
+      for ((key, value) <- other.samplesStats) {
+        // Merge into an entry created here when the sample is new, instead of storing `value` itself
+        this.samplesStats.getOrElseUpdate(key, new SampleStats) += value
+      }
+      for ((field, fieldMap) <- other.generalStats) {
+        // Counts are added into a map owned by this object, so the field maps of `other`
+        // are only read here and are never shared with this object
+        mergeStatsMap(this.generalStats.getOrElseUpdate(field, mutable.Map()), fieldMap)
+      }
+    }
+  }
+
+  def mergeStatsMap(m1: mutable.Map[Any, Int], m2: mutable.Map[Any, Int]): Unit = {
+    // Adds every count of m2 onto the count m1 holds for the same key (0 when m1 lacks the key)
+    m2.foreach { case (key, count) => m1(key) = m1.getOrElse(key, 0) + count }
+  }
+
+  def mergeNestedStatsMap(m1: mutable.Map[String, mutable.Map[Any, Int]], m2: Map[String, Map[Any, Int]]): Unit = {
+    // Folds the immutable field -> value -> count map m2 into the mutable accumulator m1;
+    // m1 is updated in place and m2 is only read
+    for ((field, fieldMap) <- m2) m1.get(field) match {
+      // Known field: add each count onto the stored one (keys not seen before start at 0)
+      case Some(counts) => for ((key, value) <- fieldMap) counts(key) = counts.getOrElse(key, 0) + value
+      // New field: keep a mutable copy of the counts
+      case None         => m1(field) = mutable.Map(fieldMap.toList: _*)
+    }
+  }
 
   var commandArgs: Args = _
 
@@ -47,74 +102,75 @@ object VcfStats extends ToolCommand {
     commandArgs = argsParser.parse(args, Args()) getOrElse sys.exit(1)
 
     val reader = new VCFFileReader(commandArgs.inputFile, false)
+
     val header = reader.getFileHeader
     val samples = header.getSampleNamesInOrder.toList
 
-    // Init
+    // Reading vcf records
+    logger.info("Start reading vcf records")
+    var counter = 0
+    val stats = new Stats
+    //init stats
     for (sample1 <- samples) {
-      genotypeOverlap(sample1) = mutable.Map()
-      allelesOverlap(sample1) = mutable.Map()
-      genotypeStats(sample1) = mutable.Map()
+      stats.samplesStats += sample1 -> new SampleStats
       for (sample2 <- samples) {
-        genotypeOverlap(sample1)(sample2) = 0
-        allelesOverlap(sample1)(sample2) = 0
+        stats.samplesStats(sample1).sampleToSample += sample2 -> new SampleToSampleStats
       }
     }
-
-    // Reading vcf records
-    logger.info("Start reading vcf records")
-    var counter = 0
-    for (record <- reader) {
-      qualStats(record.getPhredScaledQual) = qualStats.getOrElse(record.getPhredScaledQual, 0) + 1
-      for (sample1 <- samples) {
+    for (record <- reader) yield {
+      mergeNestedStatsMap(stats.generalStats, checkGeneral(record))
+      for (sample1 <- samples) yield {
         val genotype = record.getGenotype(sample1)
-        checkGenotype(genotype)
+        mergeNestedStatsMap(stats.samplesStats(sample1).genotypeStats, checkGenotype(genotype))
         for (sample2 <- samples) {
           val genotype2 = record.getGenotype(sample2)
           if (genotype.getAlleles == genotype2.getAlleles)
-            genotypeOverlap(sample1)(sample2) = genotypeOverlap(sample1)(sample2) + 1
-          for (allele <- genotype.getAlleles)
-            if (genotype2.getAlleles.exists(_.basesMatch(allele)))
-              allelesOverlap(sample1)(sample2) = allelesOverlap(sample1)(sample2) + 1
+            stats.samplesStats(sample1).sampleToSample(sample2).genotypeOverlap += 1
+          stats.samplesStats(sample1).sampleToSample(sample2).alleleOverlap += genotype.getAlleles.count(allele => genotype2.getAlleles.exists(_.basesMatch(allele)))
         }
       }
+
       counter += 1
       if (counter % 100000 == 0) logger.info(counter + " variants done")
     }
+
     logger.info(counter + " variants done")
     logger.info("Done reading vcf records")
 
-    plotXy(writeField("QUAL", qualStats.toMap))
-    writeGenotypeFields(commandArgs.outputDir + "/genotype_", samples)
-    writeOverlap(genotypeOverlap, commandArgs.outputDir + "/sample_compare/genotype_overlap", samples)
-    writeOverlap(allelesOverlap, commandArgs.outputDir + "/sample_compare/allele_overlap", samples)
+    plotXy(writeField("QUAL", stats.generalStats.getOrElse("QUAL", mutable.Map())))
+    writeGenotypeFields(stats, commandArgs.outputDir + "/genotype_", samples)
+    writeOverlap(stats, _.genotypeOverlap, commandArgs.outputDir + "/sample_compare/genotype_overlap", samples)
+    writeOverlap(stats, _.alleleOverlap, commandArgs.outputDir + "/sample_compare/allele_overlap", samples)
 
     logger.info("Done")
   }
 
-  def checkGenotype(genotype: Genotype): Unit = {
+  def checkGeneral(record: VariantContext): Map[String, Map[Any, Int]] = {
+    // A single observation of this record's phred-scaled quality, filed under the "QUAL" field
+    Map("QUAL" -> Map(record.getPhredScaledQual -> 1))
+  }
+
+  def checkGenotype(genotype: Genotype): Map[String, Map[Any, Int]] = {
     val sample = genotype.getSampleName
     val dp = if (genotype.hasDP) genotype.getDP else "not set"
-    if (!genotypeStats(sample).contains("DP")) genotypeStats(sample)("DP") = mutable.Map()
-    genotypeStats(sample)("DP")(dp) = genotypeStats(sample)("DP").getOrElse(dp, 0) + 1
-
     val gq = if (genotype.hasGQ) genotype.getGQ else "not set"
-    if (!genotypeStats(sample).contains("GQ")) genotypeStats(sample)("GQ") = mutable.Map()
-    genotypeStats(sample)("GQ")(gq) = genotypeStats(sample)("GQ").getOrElse(gq, 0) + 1
 
     //TODO: add AD field
+
+    Map("DP" -> Map(dp -> 1),
+      "GQ" -> Map(gq -> 1))
   }
 
-  def writeGenotypeFields(prefix: String, samples: List[String]) {
+  def writeGenotypeFields(stats: Stats, prefix: String, samples: List[String]) {
     val fields = List("DP", "GQ")
     for (field <- fields) {
       val file = new File(prefix + field + ".tsv")
       file.getParentFile.mkdirs()
       val writer = new PrintWriter(file)
       writer.println(samples.mkString("\t", "\t", ""))
-      val keySet = (for (sample <- samples) yield genotypeStats(sample)(field).keySet).fold(Set[Any]())(_ ++ _)
+      val keySet = (for (sample <- samples) yield stats.samplesStats(sample).genotypeStats(field).keySet).fold(Set[Any]())(_ ++ _)
       for (key <- keySet.toList.sortWith(sortAnyAny(_, _))) {
-        val values = for (sample <- samples) yield genotypeStats(sample)(field).getOrElse(key, 0)
+        val values = for (sample <- samples) yield stats.samplesStats(sample).genotypeStats(field).getOrElse(key, 0)
         writer.println(values.mkString(key + "\t", "\t", ""))
       }
       writer.close()
@@ -122,7 +178,7 @@ object VcfStats extends ToolCommand {
     }
   }
 
-  def writeField(prefix: String, data: Map[Any, Int]): File = {
+  def writeField(prefix: String, data: mutable.Map[Any, Int]): File = {
     val file = new File(commandArgs.outputDir + "/" + prefix + ".tsv")
     println(file)
     file.getParentFile.mkdirs()
@@ -148,7 +204,8 @@ object VcfStats extends ToolCommand {
     }
   }
 
-  def writeOverlap(overlap: mutable.Map[String, mutable.Map[String, Int]], prefix: String, samples: List[String]): Unit = {
+  def writeOverlap(stats: Stats, function: SampleToSampleStats => Int,
+                   prefix: String, samples: List[String]): Unit = {
     val absFile = new File(prefix + ".abs.tsv")
     val relFile = new File(prefix + ".rel.tsv")
 
@@ -160,10 +217,11 @@ object VcfStats extends ToolCommand {
     absWriter.println(samples.mkString("\t", "\t", ""))
     relWriter.println(samples.mkString("\t", "\t", ""))
     for (sample1 <- samples) {
-      val values = for (sample2 <- samples) yield overlap.getOrElse(sample1, mutable.Map()).getOrElse(sample2, 0)
+      val values = for (sample2 <- samples) yield function(stats.samplesStats(sample1).sampleToSample(sample2))
+
       absWriter.println(values.mkString(sample1 + "\t", "\t", ""))
 
-      val total = overlap.getOrElse(sample1, mutable.Map()).getOrElse(sample1, 0)
+      val total = function(stats.samplesStats(sample1).sampleToSample(sample1))
       relWriter.println(values.map(_.toFloat / total).mkString(sample1 + "\t", "\t", ""))
     }
     absWriter.close()
-- 
GitLab