From 879c66bb715c05c9bcb3452230e3e6fc291e76be Mon Sep 17 00:00:00 2001
From: bow <bow@bow.web.id>
Date: Thu, 30 Oct 2014 12:12:15 +0100
Subject: [PATCH] Use mocks to avoid writing to disk when testing

---
 .../nl/lumc/sasc/biopet/tools/WipeReads.scala | 55 ++++++-------
 .../sasc/biopet/tools/WipeReadsUnitTest.scala | 81 ++++++++-----------
 2 files changed, 59 insertions(+), 77 deletions(-)

diff --git a/biopet-framework/src/main/scala/nl/lumc/sasc/biopet/tools/WipeReads.scala b/biopet-framework/src/main/scala/nl/lumc/sasc/biopet/tools/WipeReads.scala
index 19054821b..a6c82c3b1 100644
--- a/biopet-framework/src/main/scala/nl/lumc/sasc/biopet/tools/WipeReads.scala
+++ b/biopet-framework/src/main/scala/nl/lumc/sasc/biopet/tools/WipeReads.scala
@@ -13,6 +13,7 @@ import htsjdk.samtools.SamReader
 import htsjdk.samtools.SamReaderFactory
 import htsjdk.samtools.QueryInterval
 import htsjdk.samtools.ValidationStringency
+import htsjdk.samtools.SAMFileWriter
 import htsjdk.samtools.SAMFileWriterFactory
 import htsjdk.samtools.SAMRecord
 import htsjdk.samtools.util.{ Interval, IntervalTreeMap }
@@ -51,7 +52,7 @@ object WipeReads extends ToolCommand {
    * @param inBam input BAM file
    * @return
    */
-  private def prepBam(inBam: File): SamReader = {
+  private def prepInBam(inBam: File): SamReader = {
     val bam = SamReaderFactory
       .make()
       .validationStringency(ValidationStringency.LENIENT)
@@ -60,6 +61,13 @@ object WipeReads extends ToolCommand {
     bam
   }
 
+  private def prepOutBam(outBam: File, templateBam: File,
+                         writeIndex: Boolean = true, async: Boolean = true): SAMFileWriter = {
+    val template = prepInBam(templateBam) // opened only to copy its header; closed below to avoid a leak
+    try new SAMFileWriterFactory().setCreateIndex(writeIndex).setUseAsyncIo(async)
+      .makeBAMWriter(template.getFileHeader, true, outBam) finally template.close()
+  }
+
   /**
    * Creates a list of intervals given an input File
    *
@@ -208,7 +216,7 @@ object WipeReads extends ToolCommand {
       else
         (r: SAMRecord) => readGroupIds.contains(r.getReadGroup.getReadGroupId)
 
-    val readyBam = prepBam(inBam)
+    val readyBam = prepInBam(inBam)
 
     val queryIntervals = ivl
       .flatMap(x => makeQueryInterval(readyBam, x))
@@ -261,44 +269,29 @@ object WipeReads extends ToolCommand {
    * @param filterFunc filter function that evaluates true for excluded SAMRecord
    * @param inBam input BAM file
    * @param outBam output BAM file
-   * @param writeIndex whether to write index for output BAM file
-   * @param async whether to write asynchronously
    * @param filteredOutBam whether to write excluded SAMRecords to their own BAM file
    */
-  def writeFilteredBam(filterFunc: (SAMRecord => Boolean), inBam: File, outBam: File,
-                       writeIndex: Boolean = true, async: Boolean = true,
-                       filteredOutBam: Option[File] = None) = {
-
-    val factory = new SAMFileWriterFactory()
-      .setCreateIndex(writeIndex)
-      .setUseAsyncIo(async)
-
-    val templateBam = prepBam(inBam)
-    val targetBam = factory.makeBAMWriter(templateBam.getFileHeader, true, outBam)
-    val filteredBam = filteredOutBam
-      .map(x => factory.makeBAMWriter(templateBam.getFileHeader, true, x))
+  def writeFilteredBam(filterFunc: (SAMRecord => Boolean), inBam: SamReader, outBam: SAMFileWriter,
+                       filteredOutBam: Option[SAMFileWriter] = None) = {
 
     logger.info("Writing output file(s) ...")
     try {
       var (incl, excl) = (0, 0)
-      for (rec <- templateBam.asScala) {
+      for (rec <- inBam.asScala) {
         if (!filterFunc(rec)) {
-          targetBam.addAlignment(rec)
+          outBam.addAlignment(rec)
           incl += 1
         } else {
           excl += 1
-          filteredBam match {
-            case None     =>
-            case Some(x)  => x.addAlignment(rec)
-          }
+          filteredOutBam.foreach(x => x.addAlignment(rec))
         }
       }
-      println(List("input_bam", "output_bam", "count_included", "count_excluded").mkString("\t"))
-      println(List(inBam.getName, outBam.getName, incl, excl).mkString("\t"))
+      println(List("count_included", "count_excluded").mkString("\t"))
+      println(List(incl, excl).mkString("\t"))
     } finally {
-      templateBam.close()
-      targetBam.close()
-      filteredBam.foreach(x => x.close())
+      inBam.close()
+      outBam.close()
+      filteredOutBam.foreach(x => x.close())
     }
   }
 
@@ -391,6 +384,7 @@ object WipeReads extends ToolCommand {
 
     val commandArgs: Args = parseArgs(args)
 
+    // inBam is passed as a File rather than a SamReader, since a SamReader allows only one active iterator at a time
     val filterFunc = makeFilterNotFunction(
       ivl = makeIntervalFromFile(commandArgs.targetRegions),
       inBam = commandArgs.inputBam,
@@ -403,10 +397,9 @@ object WipeReads extends ToolCommand {
 
     writeFilteredBam(
       filterFunc,
-      commandArgs.inputBam,
-      commandArgs.outputBam,
-      writeIndex = !commandArgs.noMakeIndex,
-      filteredOutBam = commandArgs.filteredOutBam
+      prepInBam(commandArgs.inputBam),
+      prepOutBam(commandArgs.outputBam, commandArgs.inputBam, writeIndex = !commandArgs.noMakeIndex),
+      commandArgs.filteredOutBam.map(x => prepOutBam(x, commandArgs.inputBam, writeIndex = !commandArgs.noMakeIndex))
     )
   }
 }
diff --git a/biopet-framework/src/test/scala/nl/lumc/sasc/biopet/tools/WipeReadsUnitTest.scala b/biopet-framework/src/test/scala/nl/lumc/sasc/biopet/tools/WipeReadsUnitTest.scala
index 15ec3a3c4..772ed6295 100644
--- a/biopet-framework/src/test/scala/nl/lumc/sasc/biopet/tools/WipeReadsUnitTest.scala
+++ b/biopet-framework/src/test/scala/nl/lumc/sasc/biopet/tools/WipeReadsUnitTest.scala
@@ -9,6 +9,7 @@ import java.nio.file.Paths
 import scala.collection.JavaConverters._
 
 import htsjdk.samtools.SAMFileHeader
+import htsjdk.samtools.SAMFileWriter
 import htsjdk.samtools.SAMLineParser
 import htsjdk.samtools.SAMReadGroupRecord
 import htsjdk.samtools.SAMRecord
@@ -18,10 +19,13 @@ import htsjdk.samtools.SamReaderFactory
 import htsjdk.samtools.ValidationStringency
 import htsjdk.samtools.util.Interval
 import org.scalatest.Matchers
+import org.mockito.Matchers._
+import org.mockito.Mockito.{ inOrder => inOrd, times, verify }
+import org.scalatest.mock.MockitoSugar
 import org.scalatest.testng.TestNGSuite
 import org.testng.annotations.Test
 
-class WipeReadsUnitTest extends TestNGSuite with Matchers {
+class WipeReadsUnitTest extends TestNGSuite with MockitoSugar with Matchers {
 
   import WipeReads._
 
@@ -371,82 +375,67 @@ class WipeReadsUnitTest extends TestNGSuite with Matchers {
     filterNotFunc(pBamRecs2(8)) shouldBe false
     filterNotFunc(pBamRecs2(9)) shouldBe false
   }
-
   @Test def testWriteSingleBamDefault() = {
     val mockFilterOutFunc = (r: SAMRecord) => Set("r03", "r04", "r05").contains(r.getReadName)
-    val outBam = makeTempBam()
-    val outBamIndex = makeTempBamIndex(outBam)
-    outBam.deleteOnExit()
-    outBamIndex.deleteOnExit()
+    val outBam = mock[SAMFileWriter]
 
     val stdout = new java.io.ByteArrayOutputStream
     Console.withOut(stdout) {
-      writeFilteredBam(mockFilterOutFunc, sBamFile1, outBam)
+      writeFilteredBam(mockFilterOutFunc, makeSamReader(sBamFile1), outBam)
     }
     stdout.toString should ===(
-      "input_bam\toutput_bam\tcount_included\tcount_excluded\n%s\t%s\t%d\t%d\n"
-        .format(sBamFile1.getName, outBam.getName, 4, 3)
+      "count_included\tcount_excluded\n%d\t%d\n"
+        .format(4, 3)
     )
 
-    val exp = makeSamReader(sBamFile3).asScala
-    val obs = makeSamReader(outBam).asScala
-    for ((e, o) <- exp.zip(obs))
-      e.getSAMString should ===(o.getSAMString)
-    outBam should be('exists)
-    outBamIndex should be('exists)
+    val exp = makeSamReader(sBamFile3).asScala.toList
+    verify(outBam, times(4)).addAlignment(anyObject.asInstanceOf[SAMRecord])
+    val obs = inOrd(outBam)
+    exp.foreach(x => {
+      obs.verify(outBam).addAlignment(x)
+    })
   }
 
   @Test def testWriteSingleBamAndFilteredBAM() = {
     val mockFilterOutFunc = (r: SAMRecord) => Set("r03", "r04", "r05").contains(r.getReadName)
-    val outBam = makeTempBam()
-    val outBamIndex = makeTempBamIndex(outBam)
-    outBam.deleteOnExit()
-    outBamIndex.deleteOnExit()
-    val filteredOutBam = makeTempBam()
-    val filteredOutBamIndex = makeTempBamIndex(filteredOutBam)
-    filteredOutBam.deleteOnExit()
-    filteredOutBamIndex.deleteOnExit()
+    val outBam = mock[SAMFileWriter]
+    val filtBam = Some(mock[SAMFileWriter])
 
     val stdout = new java.io.ByteArrayOutputStream
     Console.withOut(stdout) {
-      writeFilteredBam(mockFilterOutFunc, sBamFile1, outBam, filteredOutBam = Some(filteredOutBam))
+      writeFilteredBam(mockFilterOutFunc, makeSamReader(sBamFile1), outBam, filteredOutBam = filtBam)
     }
     stdout.toString should ===(
-      "input_bam\toutput_bam\tcount_included\tcount_excluded\n%s\t%s\t%d\t%d\n"
-        .format(sBamFile1.getName, outBam.getName, 4, 3)
+      "count_included\tcount_excluded\n%d\t%d\n"
+        .format(4, 3)
     )
 
     val exp = makeSamReader(sBamFile4).asScala
-    val obs = makeSamReader(filteredOutBam).asScala
-    for ((e, o) <- exp.zip(obs))
-      e.getSAMString should ===(o.getSAMString)
-    outBam should be('exists)
-    outBamIndex should be('exists)
-    filteredOutBam should be('exists)
-    filteredOutBamIndex should be('exists)
+    verify(filtBam.get, times(3)).addAlignment(anyObject.asInstanceOf[SAMRecord])
+    val obs = inOrd(filtBam.get)
+    exp.foreach(x => {
+      obs.verify(filtBam.get).addAlignment(x)
+    })
   }
 
   @Test def testWritePairBamDefault() = {
     val mockFilterOutFunc = (r: SAMRecord) => Set("r03", "r04", "r05").contains(r.getReadName)
-    val outBam = makeTempBam()
-    val outBamIndex = makeTempBamIndex(outBam)
-    outBam.deleteOnExit()
-    outBamIndex.deleteOnExit()
+    val outBam = mock[SAMFileWriter]
 
     val stdout = new java.io.ByteArrayOutputStream
     Console.withOut(stdout) {
-      writeFilteredBam(mockFilterOutFunc, pBamFile1, outBam)
+      writeFilteredBam(mockFilterOutFunc, makeSamReader(pBamFile1), outBam)
     }
     stdout.toString should ===(
-      "input_bam\toutput_bam\tcount_included\tcount_excluded\n%s\t%s\t%d\t%d\n"
-        .format(pBamFile1.getName, outBam.getName, 8, 6)
+      "count_included\tcount_excluded\n%d\t%d\n"
+        .format(8, 6)
     )
-    val exp = makeSamReader(pBamFile3).asScala
-    val obs = makeSamReader(outBam).asScala
-    for ((e, o) <- exp.zip(obs))
-      e.getSAMString should ===(o.getSAMString)
-    outBam should be('exists)
-    outBamIndex should be('exists)
+    val exp = makeSamReader(pBamFile3).asScala.toList
+    verify(outBam, times(8)).addAlignment(anyObject.asInstanceOf[SAMRecord])
+    val obs = inOrd(outBam)
+    exp.foreach(x => {
+      obs.verify(outBam).addAlignment(x)
+    })
   }
 
   @Test def testArgsMinimum() = {
-- 
GitLab