Skip to content
Snippets Groups Projects
Commit cedd54be authored by Peter van 't Hof's avatar Peter van 't Hof
Browse files

Refactor RegionAfCount, 10-20 times faster now

parent 2b62e291
No related branches found
No related tags found
No related merge requests found
......@@ -38,19 +38,19 @@ import scala.math._
object RegionAfCount extends ToolCommand {
case class Args(bedFile: File = null,
outputFile: File = null,
scatterpPlot: Option[File] = None,
outputPrefix: String = null,
scatterpPlot: Boolean = false,
vcfFiles: List[File] = Nil) extends AbstractArgs
class OptParser extends AbstractOptParser {
opt[File]('b', "bedFile") required () maxOccurs 1 valueName "<file>" action { (x, c) =>
opt[File]('b', "bedFile") unbounded () required () maxOccurs 1 valueName "<file>" action { (x, c) =>
c.copy(bedFile = x)
}
opt[File]('o', "outputFile") required () maxOccurs 1 valueName "<file>" action { (x, c) =>
c.copy(outputFile = x)
opt[String]('o', "outputPrefix") unbounded () required () maxOccurs 1 valueName "<file prefix>" action { (x, c) =>
c.copy(outputPrefix = x)
}
opt[File]('s', "scatterPlot") maxOccurs 1 valueName "<file>" action { (x, c) =>
c.copy(scatterpPlot = Some(x))
opt[Unit]('s', "scatterPlot") unbounded () action { (x, c) =>
c.copy(scatterpPlot = true)
}
opt[File]('V', "vcfFile") unbounded () minOccurs 1 action { (x, c) =>
c.copy(vcfFiles = c.vcfFiles ::: x :: Nil)
......@@ -71,68 +71,94 @@ object RegionAfCount extends ToolCommand {
val combinedBedRecords = bedRecords.combineOverlap
logger.info(s"${combinedBedRecords.allRecords.size} left")
logger.info(s"${combinedBedRecords.allRecords.size * cmdArgs.vcfFiles.size} query's to do")
logger.info("Reading vcf files")
case class AfCounts(var names: Double = 0,
var namesExons: Double = 0,
var namesIntrons: Double = 0,
var namesCoding: Double = 0,
var utr: Double = 0,
var utr5: Double = 0,
var utr3: Double = 0,
var exons: Map[String, Double] = Map(),
var intron: Map[String, Double] = Map())
var c = 0
val afCountsRaw = for (region <- combinedBedRecords.allRecords.par) yield {
val sum = (for (vcfFile <- cmdArgs.vcfFiles.par) yield vcfFile -> {
val afCounts = mutable.Map[String, Double]()
val reader = new VCFFileReader(vcfFile, true)
val it = reader.query(region.chr, region.start, region.end)
for (variant <- it) {
val afCounts = (for (vcfFile <- cmdArgs.vcfFiles.par) yield vcfFile -> {
val reader = new VCFFileReader(vcfFile, true)
val afCounts: mutable.Map[String, AfCounts] = mutable.Map()
for (region <- combinedBedRecords.allRecords) yield {
val originals = region.originals()
for (variant <- reader.query(region.chr, region.start, region.end)) {
val sum = (variant.getAttribute("AF", 0) match {
case a: util.ArrayList[_] => a.map(_.toString.toDouble).toArray
case s => Array(s.toString.toDouble)
}).sum
region.originals()
.map(x => x.name.getOrElse(s"${x.chr}:${x.start}-${x.end}"))
.distinct
.foreach(name => afCounts += name -> (afCounts.getOrElse(name, 0.0) + sum))
val interval = BedRecord(variant.getContig, variant.getStart, variant.getEnd)
originals.foreach { x =>
val name = x.name.getOrElse(s"${x.chr}:${x.start}-${x.end}")
if (!afCounts.contains(name)) afCounts += name -> AfCounts()
afCounts(name).names += sum
val exons = x.exons.getOrElse(Seq()).filter(_.overlapWith(interval))
val introns = x.introns.getOrElse(Seq()).filter(_.overlapWith(interval))
val utr5 = x.utr5.map(_.overlapWith(interval))
val utr3 = x.utr3.map(_.overlapWith(interval))
if (exons.nonEmpty) {
afCounts(name).namesExons += sum
if (!utr5.getOrElse(false) && !utr3.getOrElse(false)) afCounts(name).namesCoding += sum
}
if (introns.nonEmpty) afCounts(name).namesIntrons += sum
if (utr5.getOrElse(false) || utr3.getOrElse(false)) afCounts(name).utr += sum
if (utr5.getOrElse(false)) afCounts(name).utr5 += sum
if (utr3.getOrElse(false)) afCounts(name).utr3 += sum
}
}
reader.close()
afCounts.toMap
}).toMap
c += 1
if (c % 100 == 0) logger.info(s"$c regions done")
c += 1
if (c % 100 == 0) logger.info(s"$c regions done")
}
afCounts.toMap
}).toMap
sum
}
logger.info(s"Done reading, ${c} regions")
logger.info(s"Done reading, $c regions")
logger.info("Writing output files")
val afCounts: Map[String, Map[File, Double]] = {
val combinedAfCounts: mutable.Map[String, mutable.Map[File, Double]] = mutable.Map()
for (x <- afCountsRaw.toList; (file, counts) <- x.toList; (name, count) <- counts) {
val map = combinedAfCounts.getOrElse(name, mutable.Map())
map += file -> (map.getOrElse(file, 0.0) + count)
combinedAfCounts += name -> map
def writeOutput(tsvFile: File, function: AfCounts => Double): Unit = {
val writer = new PrintWriter(tsvFile)
writer.println("\t" + cmdArgs.vcfFiles.map(_.getName).mkString("\t"))
for (r <- cmdArgs.vcfFiles.foldLeft(Set[String]())((a, b) => a ++ afCounts(b).keySet)) {
writer.print(r + "\t")
writer.println(cmdArgs.vcfFiles.map(x => function(afCounts(x).getOrElse(r, AfCounts()))).mkString("\t"))
}
combinedAfCounts.map(x => x._1 -> x._2.toMap).toMap
}
logger.info("Writing output file")
writer.close()
val writer = new PrintWriter(cmdArgs.outputFile)
writer.println("\t" + cmdArgs.vcfFiles.map(_.getName).mkString("\t"))
for (r <- afCounts.keys) {
writer.print(r + "\t")
writer.println(cmdArgs.vcfFiles.map(afCounts(r).getOrElse(_, 0.0)).mkString("\t"))
if (cmdArgs.scatterpPlot) generatePlot(tsvFile)
}
writer.close()
cmdArgs.scatterpPlot.foreach { scatterPlotFile =>
logger.info("Generate plot")
def generatePlot(tsvFile: File): Unit = {
logger.info(s"Generate plot for $tsvFile")
val scatterPlot = new ScatterPlot(null)
scatterPlot.input = cmdArgs.outputFile
scatterPlot.output = scatterPlotFile
scatterPlot.input = tsvFile
scatterPlot.output = new File(tsvFile.getAbsolutePath.stripSuffix(".tsv") + ".png")
scatterPlot.ylabel = Some("Sum of AFs")
scatterPlot.width = Some(1200)
scatterPlot.height = Some(1000)
scatterPlot.runLocal()
}
for (
arg <- List[(File, AfCounts => Double)](
(new File(cmdArgs.outputPrefix + ".names.tsv"), _.names),
(new File(cmdArgs.outputPrefix + ".names.exons_only.tsv"), _.namesExons),
(new File(cmdArgs.outputPrefix + ".names.introns_only.tsv"), _.namesIntrons),
(new File(cmdArgs.outputPrefix + ".names.coding.tsv"), _.namesCoding),
(new File(cmdArgs.outputPrefix + ".names.utr.tsv"), _.utr),
(new File(cmdArgs.outputPrefix + ".names.utr5.tsv"), _.utr5),
(new File(cmdArgs.outputPrefix + ".names.utr3.tsv"), _.utr3)
).par
) writeOutput(arg._1, arg._2)
logger.info("Done")
}
......
......@@ -23,27 +23,33 @@ case class BedRecord(chr: String,
else _originals
}
def overlapWith(record: BedRecord): Boolean = {
if (chr != record.chr) false
else if (start <= record.end && record.start <= end) true
else false
}
def length = end - start + 1
lazy val exons = if (blockCount.isDefined && blockSizes.length > 0 && blockStarts.length > 0) {
Some(for (i <- 0 to blockCount.get) yield {
Some(for (i <- 0 until blockCount.get) yield {
val exonNumber = strand match {
case Some(false) => blockCount.get - i
case _ => i + 1
}
BedRecord(chr, start + blockStarts(i), start + blockStarts(i) + blockSizes(i),
name.map(_ + s"_exon-$exonNumber"), _originals = List(this))
Some(s"exon-$exonNumber"), _originals = List(this))
})
} else None
lazy val introns = if (blockCount.isDefined && blockSizes.length > 0 && blockStarts.length > 0) {
Some(for (i <- 0 to (blockCount.get - 1)) yield {
Some(for (i <- 0 until (blockCount.get - 1)) yield {
val intronNumber = strand match {
case Some(false) => blockCount.get - i
case _ => i + 1
}
BedRecord(chr, start + start + blockStarts(i) + blockSizes(i) + 1, start + blockStarts(i + 1) - 1,
name.map(_ + s"_intron-$intronNumber"), _originals = List(this))
BedRecord(chr, start + blockStarts(i) + blockSizes(i) + 1, start + blockStarts(i + 1) - 1,
Some(s"intron-$intronNumber"), _originals = List(this))
})
} else None
......@@ -56,8 +62,10 @@ case class BedRecord(chr: String,
}
lazy val utr3 = (strand, thickStart, thickEnd) match {
case (Some(false), Some(tStart), Some(tEnd)) => Some(BedRecord(chr, start, tStart - 1, name.map(_ + "_utr3")))
case (Some(true), Some(tStart), Some(tEnd)) => Some(BedRecord(chr, tEnd + 1, end, name.map(_ + "_utr3")))
case (Some(false), Some(tStart), Some(tEnd)) if (tStart > start && tEnd < end) =>
Some(BedRecord(chr, start, tStart - 1, name.map(_ + "_utr3")))
case (Some(true), Some(tStart), Some(tEnd)) if (tStart > start && tEnd < end) =>
Some(BedRecord(chr, tEnd + 1, end, name.map(_ + "_utr3")))
case _ => None
}
......@@ -79,7 +87,7 @@ case class BedRecord(chr: String,
require(start <= end, "Start is greater then end")
(thickStart, thickEnd) match {
case (Some(s), Some(e)) => require(s <= e, "Thick start is greater then end")
case _ =>
case _ =>
}
blockCount match {
case Some(count) => {
......@@ -105,7 +113,7 @@ object BedRecord {
values.lift(5).map {
case "-" => false
case "+" => true
case _ => throw new IllegalStateException("Strand (column 6) must be '+' or '-'")
case _ => throw new IllegalStateException("Strand (column 6) must be '+' or '-'")
},
values.lift(6).map(_.toInt),
values.lift(7) map (_.toInt),
......
......@@ -25,7 +25,7 @@ class BedRecordList(val chrRecords: Map[String, List[BedRecord]], header: List[S
.dropWhile(_.end < record.start)
.takeWhile(_.start <= record.end)
def length = allRecords.foldLeft(0L)((a,b) => a + b.length)
def length = allRecords.foldLeft(0L)((a, b) => a + b.length)
def squishBed(strandSensitive: Boolean = true) = BedRecordList.fromList {
(for ((chr, records) <- sort.chrRecords; record <- records) yield {
......@@ -37,15 +37,15 @@ class BedRecordList(val chrRecords: Map[String, List[BedRecord]], header: List[S
} else {
overlaps
.foldLeft(List(record))((result, overlap) => {
(for (r <- result) yield {
(overlap.start < r.start, overlap.end > r.end) match {
case (true, true) => Nil
case (true, false) => List(r.copy(start = overlap.end + 1))
case (false, true) => List(r.copy(end = overlap.start - 1))
case (false, false) => List(r.copy(end = overlap.start - 1), r.copy(start = overlap.end + 1))
}
}).flatten
})
(for (r <- result) yield {
(overlap.start < r.start, overlap.end > r.end) match {
case (true, true) => Nil
case (true, false) => List(r.copy(start = overlap.end + 1))
case (false, true) => List(r.copy(end = overlap.start - 1))
case (false, false) => List(r.copy(end = overlap.start - 1), r.copy(start = overlap.end + 1))
}
}).flatten
})
}
}).flatten
}
......@@ -77,7 +77,7 @@ class BedRecordList(val chrRecords: Map[String, List[BedRecord]], header: List[S
object BedRecordList {
def fromListWithHeader(records: Traversable[BedRecord],
header: List[String]): BedRecordList = fromListWithHeader(records.toIterator, header)
header: List[String]): BedRecordList = fromListWithHeader(records.toIterator, header)
def fromListWithHeader(records: TraversableOnce[BedRecord], header: List[String]): BedRecordList = {
val map = mutable.Map[String, ListBuffer[BedRecord]]()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment