From 879c66bb715c05c9bcb3452230e3e6fc291e76be Mon Sep 17 00:00:00 2001 From: bow <bow@bow.web.id> Date: Thu, 30 Oct 2014 12:12:15 +0100 Subject: [PATCH] Use mocks to avoid writing to disk when testing --- .../nl/lumc/sasc/biopet/tools/WipeReads.scala | 55 ++++++------- .../sasc/biopet/tools/WipeReadsUnitTest.scala | 81 ++++++++----------- 2 files changed, 59 insertions(+), 77 deletions(-) diff --git a/biopet-framework/src/main/scala/nl/lumc/sasc/biopet/tools/WipeReads.scala b/biopet-framework/src/main/scala/nl/lumc/sasc/biopet/tools/WipeReads.scala index 19054821b..a6c82c3b1 100644 --- a/biopet-framework/src/main/scala/nl/lumc/sasc/biopet/tools/WipeReads.scala +++ b/biopet-framework/src/main/scala/nl/lumc/sasc/biopet/tools/WipeReads.scala @@ -13,6 +13,7 @@ import htsjdk.samtools.SamReader import htsjdk.samtools.SamReaderFactory import htsjdk.samtools.QueryInterval import htsjdk.samtools.ValidationStringency +import htsjdk.samtools.SAMFileWriter import htsjdk.samtools.SAMFileWriterFactory import htsjdk.samtools.SAMRecord import htsjdk.samtools.util.{ Interval, IntervalTreeMap } @@ -51,7 +52,7 @@ object WipeReads extends ToolCommand { * @param inBam input BAM file * @return */ - private def prepBam(inBam: File): SamReader = { + private def prepInBam(inBam: File): SamReader = { val bam = SamReaderFactory .make() .validationStringency(ValidationStringency.LENIENT) @@ -60,6 +61,13 @@ object WipeReads extends ToolCommand { bam } + private def prepOutBam(outBam: File, templateBam: File, + writeIndex: Boolean = true, async: Boolean = true): SAMFileWriter = + new SAMFileWriterFactory() + .setCreateIndex(writeIndex) + .setUseAsyncIo(async) + .makeBAMWriter(prepInBam(templateBam).getFileHeader, true, outBam) + /** * Creates a list of intervals given an input File * @@ -208,7 +216,7 @@ object WipeReads extends ToolCommand { else (r: SAMRecord) => readGroupIds.contains(r.getReadGroup.getReadGroupId) - val readyBam = prepBam(inBam) + val readyBam = prepInBam(inBam) val queryIntervals = ivl .flatMap(x => makeQueryInterval(readyBam, x)) @@ -261,44 +269,29 @@ object WipeReads extends ToolCommand { * @param filterFunc filter function that evaluates true for excluded SAMRecord * @param inBam input BAM file * @param outBam output BAM file - * @param writeIndex whether to write index for output BAM file - * @param async whether to write asynchronously * @param filteredOutBam whether to write excluded SAMRecords to their own BAM file */ - def writeFilteredBam(filterFunc: (SAMRecord => Boolean), inBam: File, outBam: File, - writeIndex: Boolean = true, async: Boolean = true, - filteredOutBam: Option[File] = None) = { - - val factory = new SAMFileWriterFactory() - .setCreateIndex(writeIndex) - .setUseAsyncIo(async) - - val templateBam = prepBam(inBam) - val targetBam = factory.makeBAMWriter(templateBam.getFileHeader, true, outBam) - val filteredBam = filteredOutBam - .map(x => factory.makeBAMWriter(templateBam.getFileHeader, true, x)) + def writeFilteredBam(filterFunc: (SAMRecord => Boolean), inBam: SamReader, outBam: SAMFileWriter, + filteredOutBam: Option[SAMFileWriter] = None) = { logger.info("Writing output file(s) ...") try { var (incl, excl) = (0, 0) - for (rec <- templateBam.asScala) { + for (rec <- inBam.asScala) { if (!filterFunc(rec)) { - targetBam.addAlignment(rec) + outBam.addAlignment(rec) incl += 1 } else { excl += 1 - filteredBam match { - case None => - case Some(x) => x.addAlignment(rec) - } + filteredOutBam.foreach(x => x.addAlignment(rec)) } } - println(List("input_bam", "output_bam", "count_included", "count_excluded").mkString("\t")) - println(List(inBam.getName, outBam.getName, incl, excl).mkString("\t")) + println(List("count_included", "count_excluded").mkString("\t")) + println(List(incl, excl).mkString("\t")) } finally { - templateBam.close() - targetBam.close() - filteredBam.foreach(x => x.close()) + inBam.close() + outBam.close() + filteredOutBam.foreach(x => x.close()) } } @@ -391,6 +384,7 @@ object WipeReads extends ToolCommand { val commandArgs: Args = parseArgs(args) + // cannot use SamReader as inBam directly since it only allows one active iterator at any given time val filterFunc = makeFilterNotFunction( ivl = makeIntervalFromFile(commandArgs.targetRegions), inBam = commandArgs.inputBam, @@ -403,10 +397,9 @@ object WipeReads extends ToolCommand { writeFilteredBam( filterFunc, - commandArgs.inputBam, - commandArgs.outputBam, - writeIndex = !commandArgs.noMakeIndex, - filteredOutBam = commandArgs.filteredOutBam + prepInBam(commandArgs.inputBam), + prepOutBam(commandArgs.outputBam, commandArgs.inputBam, writeIndex = !commandArgs.noMakeIndex), + commandArgs.filteredOutBam.map(x => prepOutBam(x, commandArgs.inputBam, writeIndex = !commandArgs.noMakeIndex)) ) } } diff --git a/biopet-framework/src/test/scala/nl/lumc/sasc/biopet/tools/WipeReadsUnitTest.scala b/biopet-framework/src/test/scala/nl/lumc/sasc/biopet/tools/WipeReadsUnitTest.scala index 15ec3a3c4..772ed6295 100644 --- a/biopet-framework/src/test/scala/nl/lumc/sasc/biopet/tools/WipeReadsUnitTest.scala +++ b/biopet-framework/src/test/scala/nl/lumc/sasc/biopet/tools/WipeReadsUnitTest.scala @@ -9,6 +9,7 @@ import java.nio.file.Paths import scala.collection.JavaConverters._ import htsjdk.samtools.SAMFileHeader +import htsjdk.samtools.SAMFileWriter import htsjdk.samtools.SAMLineParser import htsjdk.samtools.SAMReadGroupRecord import htsjdk.samtools.SAMRecord @@ -18,10 +19,13 @@ import htsjdk.samtools.SamReaderFactory import htsjdk.samtools.ValidationStringency import htsjdk.samtools.util.Interval import org.scalatest.Matchers +import org.mockito.Matchers._ +import org.mockito.Mockito.{ inOrder => inOrd, times, verify } +import org.scalatest.mock.MockitoSugar import org.scalatest.testng.TestNGSuite import org.testng.annotations.Test -class WipeReadsUnitTest extends TestNGSuite with Matchers { +class WipeReadsUnitTest extends TestNGSuite with MockitoSugar with Matchers { import WipeReads._ @@ -371,82 +375,67 @@ class WipeReadsUnitTest extends TestNGSuite with Matchers { filterNotFunc(pBamRecs2(8)) shouldBe false filterNotFunc(pBamRecs2(9)) shouldBe false } - @Test def testWriteSingleBamDefault() = { val mockFilterOutFunc = (r: SAMRecord) => Set("r03", "r04", "r05").contains(r.getReadName) - val outBam = makeTempBam() - val outBamIndex = makeTempBamIndex(outBam) - outBam.deleteOnExit() - outBamIndex.deleteOnExit() + val outBam = mock[SAMFileWriter] val stdout = new java.io.ByteArrayOutputStream Console.withOut(stdout) { - writeFilteredBam(mockFilterOutFunc, sBamFile1, outBam) + writeFilteredBam(mockFilterOutFunc, makeSamReader(sBamFile1), outBam) } stdout.toString should ===( - "input_bam\toutput_bam\tcount_included\tcount_excluded\n%s\t%s\t%d\t%d\n" - .format(sBamFile1.getName, outBam.getName, 4, 3) + "count_included\tcount_excluded\n%d\t%d\n" + .format(4, 3) ) - val exp = makeSamReader(sBamFile3).asScala - val obs = makeSamReader(outBam).asScala - for ((e, o) <- exp.zip(obs)) - e.getSAMString should ===(o.getSAMString) - outBam should be('exists) - outBamIndex should be('exists) + val exp = makeSamReader(sBamFile3).asScala.toList + verify(outBam, times(4)).addAlignment(anyObject.asInstanceOf[SAMRecord]) + val obs = inOrd(outBam) + exp.foreach(x => { + obs.verify(outBam).addAlignment(x) + }) } @Test def testWriteSingleBamAndFilteredBAM() = { val mockFilterOutFunc = (r: SAMRecord) => Set("r03", "r04", "r05").contains(r.getReadName) - val outBam = makeTempBam() - val outBamIndex = makeTempBamIndex(outBam) - outBam.deleteOnExit() - outBamIndex.deleteOnExit() - val filteredOutBam = makeTempBam() - val filteredOutBamIndex = makeTempBamIndex(filteredOutBam) - filteredOutBam.deleteOnExit() - filteredOutBamIndex.deleteOnExit() + val outBam = mock[SAMFileWriter] + val filtBam = Some(mock[SAMFileWriter]) val stdout = new java.io.ByteArrayOutputStream Console.withOut(stdout) { - writeFilteredBam(mockFilterOutFunc, sBamFile1, outBam, filteredOutBam = Some(filteredOutBam)) + writeFilteredBam(mockFilterOutFunc, makeSamReader(sBamFile1), outBam, filteredOutBam = filtBam) } stdout.toString should ===( - "input_bam\toutput_bam\tcount_included\tcount_excluded\n%s\t%s\t%d\t%d\n" - .format(sBamFile1.getName, outBam.getName, 4, 3) + "count_included\tcount_excluded\n%d\t%d\n" + .format(4, 3) ) val exp = makeSamReader(sBamFile4).asScala - val obs = makeSamReader(filteredOutBam).asScala - for ((e, o) <- exp.zip(obs)) - e.getSAMString should ===(o.getSAMString) - outBam should be('exists) - outBamIndex should be('exists) - filteredOutBam should be('exists) - filteredOutBamIndex should be('exists) + verify(filtBam.get, times(3)).addAlignment(anyObject.asInstanceOf[SAMRecord]) + val obs = inOrd(filtBam.get) + exp.foreach(x => { + obs.verify(filtBam.get).addAlignment(x) + }) } @Test def testWritePairBamDefault() = { val mockFilterOutFunc = (r: SAMRecord) => Set("r03", "r04", "r05").contains(r.getReadName) - val outBam = makeTempBam() - val outBamIndex = makeTempBamIndex(outBam) - outBam.deleteOnExit() - outBamIndex.deleteOnExit() + val outBam = mock[SAMFileWriter] val stdout = new java.io.ByteArrayOutputStream Console.withOut(stdout) { - writeFilteredBam(mockFilterOutFunc, pBamFile1, outBam) + writeFilteredBam(mockFilterOutFunc, makeSamReader(pBamFile1), outBam) } stdout.toString should ===( - "input_bam\toutput_bam\tcount_included\tcount_excluded\n%s\t%s\t%d\t%d\n" - .format(pBamFile1.getName, outBam.getName, 8, 6) + "count_included\tcount_excluded\n%d\t%d\n" + .format(8, 6) ) - val exp = makeSamReader(pBamFile3).asScala - val obs = makeSamReader(outBam).asScala - for ((e, o) <- exp.zip(obs)) - e.getSAMString should ===(o.getSAMString) - outBam should be('exists) - outBamIndex should be('exists) + val exp = makeSamReader(pBamFile3).asScala.toList + verify(outBam, times(8)).addAlignment(anyObject.asInstanceOf[SAMRecord]) + val obs = inOrd(outBam) + exp.foreach(x => { + obs.verify(outBam).addAlignment(x) + }) } @Test def testArgsMinimum() = { -- GitLab