Skip to content
Snippets Groups Projects
Commit 879c66bb authored by bow's avatar bow
Browse files

Use mocks to avoid writing to disk when testing

parent 118f09c8
No related branches found
No related tags found
No related merge requests found
......@@ -13,6 +13,7 @@ import htsjdk.samtools.SamReader
import htsjdk.samtools.SamReaderFactory
import htsjdk.samtools.QueryInterval
import htsjdk.samtools.ValidationStringency
import htsjdk.samtools.SAMFileWriter
import htsjdk.samtools.SAMFileWriterFactory
import htsjdk.samtools.SAMRecord
import htsjdk.samtools.util.{ Interval, IntervalTreeMap }
......@@ -51,7 +52,7 @@ object WipeReads extends ToolCommand {
* @param inBam input BAM file
* @return
*/
private def prepBam(inBam: File): SamReader = {
private def prepInBam(inBam: File): SamReader = {
val bam = SamReaderFactory
.make()
.validationStringency(ValidationStringency.LENIENT)
......@@ -60,6 +61,13 @@ object WipeReads extends ToolCommand {
bam
}
/**
 * Creates a BAM writer for an output file, reusing the header of a template BAM file.
 *
 * @param outBam output BAM file to write to
 * @param templateBam BAM file whose header is used for the output
 * @param writeIndex whether to create an index for the output BAM file
 * @param async whether to write asynchronously
 * @return writer for the output BAM file; closing it is left to the caller
 */
private def prepOutBam(outBam: File, templateBam: File,
                       writeIndex: Boolean = true, async: Boolean = true): SAMFileWriter = {
  // The template reader is opened only to obtain its header, so close it as soon
  // as the header has been read instead of leaving the handle open.
  val templateReader = prepInBam(templateBam)
  val header =
    try templateReader.getFileHeader
    finally templateReader.close()
  new SAMFileWriterFactory()
    .setCreateIndex(writeIndex)
    .setUseAsyncIo(async)
    .makeBAMWriter(header, true, outBam)
}
/**
* Creates a list of intervals given an input File
*
......@@ -208,7 +216,7 @@ object WipeReads extends ToolCommand {
else
(r: SAMRecord) => readGroupIds.contains(r.getReadGroup.getReadGroupId)
val readyBam = prepBam(inBam)
val readyBam = prepInBam(inBam)
val queryIntervals = ivl
.flatMap(x => makeQueryInterval(readyBam, x))
......@@ -261,44 +269,29 @@ object WipeReads extends ToolCommand {
* @param filterFunc filter function that evaluates true for excluded SAMRecord
* @param inBam input BAM file
* @param outBam output BAM file
* @param writeIndex whether to write index for output BAM file
* @param async whether to write asynchronously
* @param filteredOutBam whether to write excluded SAMRecords to their own BAM file
*/
def writeFilteredBam(filterFunc: (SAMRecord => Boolean), inBam: File, outBam: File,
writeIndex: Boolean = true, async: Boolean = true,
filteredOutBam: Option[File] = None) = {
// File-based lines: the writers are built here from the factory settings,
// using the header of the reader opened on inBam.
val factory = new SAMFileWriterFactory()
.setCreateIndex(writeIndex)
.setUseAsyncIo(async)
val templateBam = prepBam(inBam)
val targetBam = factory.makeBAMWriter(templateBam.getFileHeader, true, outBam)
val filteredBam = filteredOutBam
.map(x => factory.makeBAMWriter(templateBam.getFileHeader, true, x))
// NOTE(review): a second signature follows, taking an already-opened SamReader and
// SAMFileWriter(s) instead of Files. This span looks like the pre- and post-change
// lines of a diff shown together -- confirm which version is current.
def writeFilteredBam(filterFunc: (SAMRecord => Boolean), inBam: SamReader, outBam: SAMFileWriter,
filteredOutBam: Option[SAMFileWriter] = None) = {
logger.info("Writing output file(s) ...")
// Single pass over the input records; the finally block below closes every
// handle even when iteration or writing throws.
try {
var (incl, excl) = (0, 0)
for (rec <- templateBam.asScala) {
for (rec <- inBam.asScala) {
// filterFunc is true for records to exclude, so a false result keeps the record.
if (!filterFunc(rec)) {
targetBam.addAlignment(rec)
outBam.addAlignment(rec)
incl += 1
} else {
// Excluded records are always counted, but only written when an extra
// output for them was supplied.
excl += 1
filteredBam match {
case None =>
case Some(x) => x.addAlignment(rec)
}
filteredOutBam.foreach(x => x.addAlignment(rec))
}
}
// Tab-separated summary on stdout: a header row followed by a value row.
println(List("input_bam", "output_bam", "count_included", "count_excluded").mkString("\t"))
println(List(inBam.getName, outBam.getName, incl, excl).mkString("\t"))
println(List("count_included", "count_excluded").mkString("\t"))
println(List(incl, excl).mkString("\t"))
} finally {
// The input reader and all writers are closed here, including handles that
// were passed in by the caller.
templateBam.close()
targetBam.close()
filteredBam.foreach(x => x.close())
inBam.close()
outBam.close()
filteredOutBam.foreach(x => x.close())
}
}
......@@ -391,6 +384,7 @@ object WipeReads extends ToolCommand {
val commandArgs: Args = parseArgs(args)
// cannot use SamReader as inBam directly since it only allows one active iterator at any given time
val filterFunc = makeFilterNotFunction(
ivl = makeIntervalFromFile(commandArgs.targetRegions),
inBam = commandArgs.inputBam,
......@@ -403,10 +397,9 @@ object WipeReads extends ToolCommand {
writeFilteredBam(
filterFunc,
commandArgs.inputBam,
commandArgs.outputBam,
writeIndex = !commandArgs.noMakeIndex,
filteredOutBam = commandArgs.filteredOutBam
prepInBam(commandArgs.inputBam),
prepOutBam(commandArgs.outputBam, commandArgs.inputBam, writeIndex = !commandArgs.noMakeIndex),
commandArgs.filteredOutBam.map(x => prepOutBam(x, commandArgs.inputBam, writeIndex = !commandArgs.noMakeIndex))
)
}
}
......@@ -9,6 +9,7 @@ import java.nio.file.Paths
import scala.collection.JavaConverters._
import htsjdk.samtools.SAMFileHeader
import htsjdk.samtools.SAMFileWriter
import htsjdk.samtools.SAMLineParser
import htsjdk.samtools.SAMReadGroupRecord
import htsjdk.samtools.SAMRecord
......@@ -18,10 +19,13 @@ import htsjdk.samtools.SamReaderFactory
import htsjdk.samtools.ValidationStringency
import htsjdk.samtools.util.Interval
import org.scalatest.Matchers
import org.mockito.Matchers._
import org.mockito.Mockito.{ inOrder => inOrd, times, verify }
import org.scalatest.mock.MockitoSugar
import org.scalatest.testng.TestNGSuite
import org.testng.annotations.Test
class WipeReadsUnitTest extends TestNGSuite with Matchers {
class WipeReadsUnitTest extends TestNGSuite with MockitoSugar with Matchers {
import WipeReads._
......@@ -371,82 +375,67 @@ class WipeReadsUnitTest extends TestNGSuite with Matchers {
filterNotFunc(pBamRecs2(8)) shouldBe false
filterNotFunc(pBamRecs2(9)) shouldBe false
}
// Checks writeFilteredBam on sBamFile1: the printed summary and the records
// handed to the main output.
@Test def testWriteSingleBamDefault() = {
// Filter that marks reads named r03, r04 and r05 as excluded.
val mockFilterOutFunc = (r: SAMRecord) => Set("r03", "r04", "r05").contains(r.getReadName)
val outBam = makeTempBam()
val outBamIndex = makeTempBamIndex(outBam)
outBam.deleteOnExit()
outBamIndex.deleteOnExit()
val outBam = mock[SAMFileWriter]
// Capture stdout so the summary printed by writeFilteredBam can be asserted.
val stdout = new java.io.ByteArrayOutputStream
Console.withOut(stdout) {
writeFilteredBam(mockFilterOutFunc, sBamFile1, outBam)
writeFilteredBam(mockFilterOutFunc, makeSamReader(sBamFile1), outBam)
}
stdout.toString should ===(
"input_bam\toutput_bam\tcount_included\tcount_excluded\n%s\t%s\t%d\t%d\n"
.format(sBamFile1.getName, outBam.getName, 4, 3)
"count_included\tcount_excluded\n%d\t%d\n"
.format(4, 3)
)
val exp = makeSamReader(sBamFile3).asScala
val obs = makeSamReader(outBam).asScala
for ((e, o) <- exp.zip(obs))
e.getSAMString should ===(o.getSAMString)
outBam should be('exists)
outBamIndex should be('exists)
val exp = makeSamReader(sBamFile3).asScala.toList
// Exactly four addAlignment calls are expected, matching the included count above.
verify(outBam, times(4)).addAlignment(anyObject.asInstanceOf[SAMRecord])
// Every record of sBamFile3 must have been passed to addAlignment, in file order.
// NOTE(review): argument matching presumably relies on SAMRecord equality -- confirm.
val obs = inOrd(outBam)
exp.foreach(x => {
obs.verify(outBam).addAlignment(x)
})
}
// Checks writeFilteredBam on sBamFile1 when a second output for excluded
// records is supplied: the printed summary and the records sent to that output.
@Test def testWriteSingleBamAndFilteredBAM() = {
// Filter that marks reads named r03, r04 and r05 as excluded.
val mockFilterOutFunc = (r: SAMRecord) => Set("r03", "r04", "r05").contains(r.getReadName)
val outBam = makeTempBam()
val outBamIndex = makeTempBamIndex(outBam)
outBam.deleteOnExit()
outBamIndex.deleteOnExit()
val filteredOutBam = makeTempBam()
val filteredOutBamIndex = makeTempBamIndex(filteredOutBam)
filteredOutBam.deleteOnExit()
filteredOutBamIndex.deleteOnExit()
val outBam = mock[SAMFileWriter]
val filtBam = Some(mock[SAMFileWriter])
// Capture stdout so the summary printed by writeFilteredBam can be asserted.
val stdout = new java.io.ByteArrayOutputStream
Console.withOut(stdout) {
writeFilteredBam(mockFilterOutFunc, sBamFile1, outBam, filteredOutBam = Some(filteredOutBam))
writeFilteredBam(mockFilterOutFunc, makeSamReader(sBamFile1), outBam, filteredOutBam = filtBam)
}
stdout.toString should ===(
"input_bam\toutput_bam\tcount_included\tcount_excluded\n%s\t%s\t%d\t%d\n"
.format(sBamFile1.getName, outBam.getName, 4, 3)
"count_included\tcount_excluded\n%d\t%d\n"
.format(4, 3)
)
// NOTE(review): unlike the sibling tests, exp is not materialised with .toList
// here, so it stays backed by the open reader -- confirm this is intended.
val exp = makeSamReader(sBamFile4).asScala
val obs = makeSamReader(filteredOutBam).asScala
for ((e, o) <- exp.zip(obs))
e.getSAMString should ===(o.getSAMString)
outBam should be('exists)
outBamIndex should be('exists)
filteredOutBam should be('exists)
filteredOutBamIndex should be('exists)
// Exactly three addAlignment calls are expected on the excluded-records output,
// matching the excluded count above.
verify(filtBam.get, times(3)).addAlignment(anyObject.asInstanceOf[SAMRecord])
// Every record of sBamFile4 must have been passed to that output, in file order.
val obs = inOrd(filtBam.get)
exp.foreach(x => {
obs.verify(filtBam.get).addAlignment(x)
})
}
// Checks writeFilteredBam on pBamFile1: the printed summary and the records
// handed to the main output.
@Test def testWritePairBamDefault() = {
// Filter that marks reads named r03, r04 and r05 as excluded.
val mockFilterOutFunc = (r: SAMRecord) => Set("r03", "r04", "r05").contains(r.getReadName)
val outBam = makeTempBam()
val outBamIndex = makeTempBamIndex(outBam)
outBam.deleteOnExit()
outBamIndex.deleteOnExit()
val outBam = mock[SAMFileWriter]
// Capture stdout so the summary printed by writeFilteredBam can be asserted.
val stdout = new java.io.ByteArrayOutputStream
Console.withOut(stdout) {
writeFilteredBam(mockFilterOutFunc, pBamFile1, outBam)
writeFilteredBam(mockFilterOutFunc, makeSamReader(pBamFile1), outBam)
}
stdout.toString should ===(
"input_bam\toutput_bam\tcount_included\tcount_excluded\n%s\t%s\t%d\t%d\n"
.format(pBamFile1.getName, outBam.getName, 8, 6)
"count_included\tcount_excluded\n%d\t%d\n"
.format(8, 6)
)
val exp = makeSamReader(pBamFile3).asScala
val obs = makeSamReader(outBam).asScala
for ((e, o) <- exp.zip(obs))
e.getSAMString should ===(o.getSAMString)
outBam should be('exists)
outBamIndex should be('exists)
val exp = makeSamReader(pBamFile3).asScala.toList
// Exactly eight addAlignment calls are expected, matching the included count above.
verify(outBam, times(8)).addAlignment(anyObject.asInstanceOf[SAMRecord])
// Every record of pBamFile3 must have been passed to addAlignment, in file order.
val obs = inOrd(outBam)
exp.foreach(x => {
obs.verify(outBam).addAlignment(x)
})
}
@Test def testArgsMinimum() = {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment