WipeReads.scala 19.2 KB
Newer Older
bow's avatar
bow committed
1
2
/**
 * Copyright (c) 2014 Leiden University Medical Center - Sequencing Analysis Support Core <sasc@lumc.nl>
 * @author Wibowo Arindrarto <w.arindrarto@lumc.nl>
 */
package nl.lumc.sasc.biopet.core.apps

7
import java.io.{ File, IOException }
bow's avatar
bow committed
8
import scala.collection.JavaConverters._
9
import scala.io.Source
bow's avatar
bow committed
10

11
import com.twitter.algebird.{ BF, BloomFilter, BloomFilterMonoid }
12
import htsjdk.samtools.AlignmentBlock
13
14
import htsjdk.samtools.SAMFileReader
import htsjdk.samtools.SAMFileReader.QueryInterval
15
import htsjdk.samtools.SAMFileWriterFactory
16
import htsjdk.samtools.SAMRecord
17
import htsjdk.tribble.index.interval.{ Interval, IntervalTree }
18
import org.apache.commons.io.FilenameUtils.getExtension
bow's avatar
bow committed
19
20
21
import org.broadinstitute.gatk.utils.commandline.{ Input, Output }

import nl.lumc.sasc.biopet.core.BiopetJavaCommandLineFunction
bow's avatar
bow committed
22
import nl.lumc.sasc.biopet.core.MainCommand
bow's avatar
bow committed
23
24
import nl.lumc.sasc.biopet.core.config.Configurable

bow's avatar
bow committed
25
// TODO: finish implementation for usage in pipelines
26
27
28
29
30
/**
 * WipeReads function class for usage in Biopet pipelines
 *
 * @param root Configuration object for the pipeline
 */
bow's avatar
bow committed
31
32
33
34
35
36
37
38
39
40
41
42
// Command line function wrapper around the WipeReads tool; `root` is the
// pipeline configuration object this function is created from.
class WipeReads(val root: Configurable) extends BiopetJavaCommandLineFunction {

  // Use the runtime name of this class as the Java main class to launch
  // (javaMainClass is presumably declared by BiopetJavaCommandLineFunction -- TODO confirm)
  javaMainClass = getClass.getName

  // Annotated input argument: the BAM file to read, which must be indexed
  @Input(doc = "Input BAM file (must be indexed)", shortName = "I", required = true)
  var inputBAM: File = _

  // Annotated output argument: the BAM file to write
  @Output(doc = "Output BAM", shortName = "o", required = true)
  var outputBAM: File = _

}

bow's avatar
bow committed
43
object WipeReads extends MainCommand {
bow's avatar
bow committed
44

bow's avatar
bow committed
45
  /** Container type for command line flags */
bow's avatar
bow committed
46
47
  type OptionMap = Map[String, Any]

bow's avatar
bow committed
48
  /**
   * Container class for interval parsing results
   *
   * Coordinates are treated as fully closed, i.e. both `start` and `end` are
   * part of the interval. Construction fails with an IllegalArgumentException
   * when `start` is larger than `end`.
   *
   * @param chrom name of the chromosome / contig
   * @param start first position covered by the interval
   * @param end last position covered by the interval
   */
  case class RawInterval(chrom: String, start: Int, end: Int) {

    require(start <= end, s"Start coordinate $start is larger than end coordinate $end")

    /** Function to check whether one interval contains the other */
    def contains(that: RawInterval): Boolean =
      this.chrom == that.chrom && this.start <= that.start && this.end >= that.end

    /**
     * Function to check whether two intervals overlap each other
     *
     * The check is symmetric: `a.overlaps(b)` always equals `b.overlaps(a)`.
     */
    def overlaps(that: RawInterval): Boolean =
      this.chrom == that.chrom && this.start <= that.end && that.start <= this.end

    /**
     * Function to merge two overlapping intervals
     *
     * @param that interval to merge with this one
     * @return interval spanning both intervals; this very instance when it already contains `that`
     * @throws IllegalArgumentException if the intervals are on different chromosomes or do not overlap
     */
    def merge(that: RawInterval): RawInterval = {
      if (this.chrom != that.chrom)
        throw new IllegalArgumentException("Can not merge RawInterval objects from different chromosomes")
      if (contains(that))
        this
      else if (overlaps(that))
        // take the outermost coordinates so the result does not depend on argument order
        RawInterval(this.chrom, math.min(this.start, that.start), math.max(this.end, that.end))
      else
        throw new IllegalArgumentException("Can not merge non-overlapping RawInterval objects")
    }

  }
bow's avatar
bow committed
80

81
82
83
84
85
86
87
88
89
90
91
  /**
   * Function to merge overlapping intervals into single intervals
   *
   * The input is fully consumed and sorted by chromosome, start, and end
   * before merging. Since merged intervals are accumulated by prepending,
   * the result is yielded in the reverse of that sort order.
   *
   * @param ri iterator yielding RawInterval objects
   * @return iterator over RawInterval objects in which overlapping inputs are merged
   */
  def mergeOverlappingIntervals(ri: Iterator[RawInterval]): Iterator[RawInterval] = {
    val sorted = ri.toList.sortBy(x => (x.chrom, x.start, x.end))
    val merged = sorted.foldLeft(List.empty[RawInterval]) {
      // the most recently added interval overlaps the current one: replace it by the merge result
      case (last :: rest, current) if last.overlaps(current) => last.merge(current) :: rest
      // otherwise start a new interval
      case (acc, current)                                    => current :: acc
    }
    merged.toIterator
  }

92
93
94
95
96
  /**
   * Function to create iterator over intervals from input interval file
   *
   * The file format is detected from the file extension; only BED files are
   * currently supported. Overlapping intervals in the file are merged.
   *
   * @param inFile input interval file
   * @return iterator over merged RawInterval objects (1-based, fully closed coordinates)
   * @throws IllegalArgumentException if the file type is not supported or a BED line is malformed
   */
  def makeRawIntervalFromFile(inFile: File): Iterator[RawInterval] = {

    /** Function to check whether a (trimmed) BED line is a header or comment line */
    def isBEDHeader(line: String): Boolean =
      line.startsWith("#") ||
        line == "track" || line.startsWith("track ") ||
        line == "browser" || line.startsWith("browser ")

    /** Function to create iterator from BED file */
    def makeRawIntervalFromBED(bedFile: File): Iterator[RawInterval] = {
      val source = Source.fromFile(bedFile)
      try {
        source.getLines()
          .map(_.trim)
          .filterNot(line => line.isEmpty || isBEDHeader(line))
          .map(line => line.split("\t") match {
            // BED file coordinates are 0-based, half open so we need to do some conversion
            case Array(chrom, start, end, _*) => RawInterval(chrom, start.toInt + 1, end.toInt)
            case _ =>
              throw new IllegalArgumentException(
                "Invalid BED line in " + bedFile.getPath + ": '" + line + "'")
          })
          // read everything before the source is closed
          .toVector
          .iterator
      } finally {
        source.close()
      }
    }

    // TODO: implementation
    /** Function to create iterator from refFlat file */
    def makeRawIntervalFromRefFlat(inFile: File): Iterator[RawInterval] = ???
    // convert coordinate to 1-based fully closed
    // parse chrom, start blocks, end blocks, strands

    // TODO: implementation
    /** Function to create iterator from GTF file */
    def makeRawIntervalFromGTF(inFile: File): Iterator[RawInterval] = ???
    // convert coordinate to 1-based fully closed
    // parse chrom, start blocks, end blocks, strands

    // detect interval file format from extension
    val iterFunc: (File => Iterator[RawInterval]) =
      if (getExtension(inFile.toString.toLowerCase) == "bed")
        makeRawIntervalFromBED
      else
        throw new IllegalArgumentException("Unexpected interval file type: " + inFile.getPath)

    mergeOverlappingIntervals(iterFunc(inFile))
  }

132
  // TODO: set minimum fraction for overlap
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
  /**
   * Function to create function to check SAMRecord for exclusion in filtered BAM file.
   *
   * The returned function evaluates all filtered-in SAMRecord to false.
   *
   * Reads overlapping the target regions are collected in a Bloom filter, so the
   * returned function may evaluate to true for a record that should be kept, at
   * a rate governed by `bloomFp`. The input BAM file is only kept open while the
   * Bloom filter is being built.
   *
   * @param iv iterator yielding RawInterval objects
   * @param inBAM input BAM file
   * @param inBAMIndex input BAM file index
   * @param filterOutMulti whether to filter out reads with same name outside target region (default: true)
   * @param minMapQ minimum MapQ of reads in target region to filter out (default: 0)
   * @param readGroupIDs read group IDs of reads in target region to filter out (default: all IDs)
   * @param bloomSize expected size of elements to contain in the Bloom filter
   * @param bloomFp expected Bloom filter false positive rate
   * @return function that checks whether a SAMRecord or String is to be excluded
   */
  def makeFilterOutFunction(iv: Iterator[RawInterval],
                            inBAM: File, inBAMIndex: File = null,
                            filterOutMulti: Boolean = true,
                            minMapQ: Int = 0, readGroupIDs: Set[String] = Set(),
                            bloomSize: Int = 100000000, bloomFp: Double = 1e-10): (SAMRecord => Boolean) = {

    // TODO: implement optional index creation
    /** Function to check for BAM file index and return a SAMFileReader given a File */
    def prepIndexedInputBAM(): SAMFileReader =
      if (inBAMIndex != null)
        new SAMFileReader(inBAM, inBAMIndex)
      else {
        val sfr = new SAMFileReader(inBAM)
        if (!sfr.hasIndex) {
          // release the file handle before rejecting the input
          sfr.close()
          throw new IllegalStateException("Input BAM file must be indexed")
        }
        sfr
      }

    /**
     * Function to list the names under which a contig may be known
     *
     * @param name contig name
     * @return the name itself, followed by the name with its "chr" prefix removed or added
     */
    def contigAliases(name: String): Seq[String] =
      if (name.startsWith("chr")) Seq(name, name.substring(3))
      else Seq(name, "chr" + name)

    /**
     * Function to query intervals from a BAM file
     *
     * The function still works when only either one of the interval or BAM
     * file contig is prepended with "chr"
     *
     * @param inBAM BAM file to query as SAMFileReader
     * @param ri raw interval object containing query
     * @return QueryInterval wrapped in Option
     */
    def monadicMakeQueryInterval(inBAM: SAMFileReader, ri: RawInterval): Option[QueryInterval] =
      contigAliases(ri.chrom)
        .find(name => inBAM.getFileHeader.getSequenceIndex(name) > -1)
        .map(name => inBAM.makeQueryInterval(name, ri.start, ri.end))

    // TODO: can we accumulate errors / exceptions instead of ignoring them?
    /**
     * Function to query mate from a BAM file
     *
     * @param inBAM BAM file to query as SAMFileReader
     * @param rec SAMRecord whose mate will be queried
     * @return SAMRecord wrapped in Option
     */
    def monadicMateQuery(inBAM: SAMFileReader, rec: SAMRecord): Option[SAMRecord] =
      // adapted from htsjdk's SAMFileReader.queryMate to better deal with multiple-mapped reads
      if (!rec.getReadPairedFlag)
        None
      else {
        if (rec.getFirstOfPairFlag == rec.getSecondOfPairFlag)
          throw new IllegalArgumentException("SAMRecord must either be first or second of pair, but not both")

        val it =
          if (rec.getMateReferenceIndex == SAMRecord.NO_ALIGNMENT_REFERENCE_INDEX)
            inBAM.queryUnmapped()
          else
            inBAM.queryAlignmentStart(rec.getMateReferenceName, rec.getMateAlignmentStart)

        val qres =
          try {
            it.asScala.toList
              .filter(x => x.getReadPairedFlag
                && rec.getAlignmentStart == x.getMateAlignmentStart
                && rec.getMateAlignmentStart == x.getAlignmentStart)
          } finally {
            it.close()
          }

        // only an unambiguous mate is reported
        qres match {
          case mate :: Nil => Some(mate)
          case _           => None
        }
      }

    /**
     * function to make IntervalTree from our RawInterval objects
     *
     * @param ri iterable over RawInterval objects
     * @return IntervalTree objects, filled with intervals from RawInterval
     */
    def makeIntervalTree(ri: Iterable[RawInterval]): IntervalTree = {
      val ivt = new IntervalTree
      ri.foreach(raw => ivt.insert(new Interval(raw.start, raw.end)))
      ivt
    }

    /**
     * Function to ensure that a SAMRecord overlaps our target regions
     *
     * This is required because htsjdk's queryOverlap method does not take into
     * account the SAMRecord splicing structure
     *
     * @param rec SAMRecord to check
     * @param ivtm mapping of a chromosome and its interval tree
     * @return true if the record is not spliced, or if at least one of its
     *         alignment blocks overlaps an interval on its chromosome
     */
    def alignmentBlockOverlaps(rec: SAMRecord, ivtm: Map[String, IntervalTree]): Boolean =
      // if SAMRecord is not spliced, assume queryOverlap has done its job
      // otherwise check for alignment block overlaps in our interval list
      if (rec.getCigarString.contains("N")) {
        // the interval file and the BAM file may differ in their use of the "chr" prefix
        val tree = contigAliases(rec.getReferenceName).flatMap(name => ivtm.get(name)).headOption
        tree.exists { ivt =>
          rec.getAlignmentBlocks.asScala.exists { ab =>
            !ivt.findOverlapping(
              new Interval(ab.getReferenceStart, ab.getReferenceStart + ab.getLength - 1)).isEmpty
          }
        }
      } else
        true

    /** filter function for read IDs */
    val rgFilter: (SAMRecord => Boolean) =
      if (readGroupIDs.isEmpty)
        (r: SAMRecord) => true
      else
        // records without a read group can never match a requested read group ID
        (r: SAMRecord) => Option(r.getReadGroup).exists(rg => readGroupIDs.contains(rg.getReadGroupId))

    /** function to get set element */
    val SAMToElem =
      if (filterOutMulti)
        (r: SAMRecord) => r.getReadName
      else
        (r: SAMRecord) => r.getReadName + "_" + r.getAlignmentStart

    val bfm = BloomFilter(bloomSize, bloomFp, 13)

    /** Function to make a BloomFilter containing one element from the SAMRecord */
    def makeBFFromSAM(rec: SAMRecord, bfm: BloomFilterMonoid): BF = {
      if (filterOutMulti)
        bfm.create(rec.getReadName)
      else if (!rec.getReadPairedFlag)
        bfm.create(SAMToElem(rec))
      else
        // to bypass querying for each mate, we store the records that the mate also has
        // namely, the read name and the alignment start
        bfm.create(SAMToElem(rec), rec.getReadName + "_" + rec.getMateAlignmentStart)
    }

    val firstBAM = prepIndexedInputBAM()

    val filteredOutSet: BF =
      try {
        /* NOTE: the interval vector here should be bypass-able if we can make
                 the BAM query intervals with Interval objects. This is not possible
                 at the moment since we can not retrieve start and end coordinates
                 of an Interval, so we resort to our own RawInterval vector
        */
        val intervals = iv.toVector
        val intervalTreeMap: Map[String, IntervalTree] = intervals
          .groupBy(x => x.chrom)
          .map({ case (key, value) => (key, makeIntervalTree(value)) })
        val queryIntervals = intervals
          .flatMap(x => monadicMakeQueryInterval(firstBAM, x))
          // makeQueryInterval only accepts a sorted QueryInterval list
          .sortBy(x => (x.referenceIndex, x.start, x.end))
          .toArray

        firstBAM.queryOverlapping(queryIntervals).asScala
          // ensure spliced reads have at least one block overlapping target region
          .filter(x => alignmentBlockOverlaps(x, intervalTreeMap))
          // filter for MAPQ on target region reads
          .filter(x => x.getMappingQuality >= minMapQ)
          // filter on specific read group IDs
          .filter(x => rgFilter(x))
          // transform SAMRecord to string
          .map(x => makeBFFromSAM(x, bfm))
          // build bloom filter using fold to prevent loading all strings to memory
          .foldLeft(bfm.create())(_.++(_))
      } finally {
        // the reader is not needed anymore once the Bloom filter is built
        firstBAM.close()
      }

    if (filterOutMulti)
      (rec: SAMRecord) => filteredOutSet.contains(rec.getReadName).isTrue
    else
      (rec: SAMRecord) => {
        if (rec.getReadPairedFlag)
          filteredOutSet.contains(rec.getReadName + "_" + rec.getAlignmentStart).isTrue &&
            filteredOutSet.contains(rec.getReadName + "_" + rec.getMateAlignmentStart).isTrue
        else
          filteredOutSet.contains(rec.getReadName + "_" + rec.getAlignmentStart).isTrue
      }
  }

bow's avatar
bow committed
332
  // TODO: implement stats file output
333
334
335
336
337
338
339
340
341
342
  /**
   * Function to filter input BAM and write its output to the filesystem
   *
   * Every record of the input BAM file is passed to the filter function.
   * Records for which it evaluates to false are written to the output BAM
   * file; the other records are written to the excluded-records BAM file when
   * one is given and are dropped otherwise.
   *
   * @param filterFunc filter function that evaluates true for excluded SAMRecord
   * @param inBAM input BAM file
   * @param outBAM output BAM file
   * @param writeIndex whether to write index for output BAM file
   * @param async whether to write asynchronously
   * @param filteredOutBAM whether to write excluded SAMRecords to their own BAM file
   */
  def writeFilteredBAM(filterFunc: (SAMRecord => Boolean), inBAM: File, outBAM: File,
                       writeIndex: Boolean = true, async: Boolean = true,
                       filteredOutBAM: File = null): Unit = {

    val writerFactory = new SAMFileWriterFactory()
      .setCreateIndex(writeIndex)
      .setUseAsyncIo(async)
    val reader = new SAMFileReader(inBAM)
    val keptWriter = writerFactory.makeBAMWriter(reader.getFileHeader, true, outBAM)
    // writer for the excluded records, only present when a target file is given
    val excludedWriter = Option(filteredOutBAM)
      .map(target => writerFactory.makeBAMWriter(reader.getFileHeader, true, target))

    try {
      reader.asScala.foreach { record =>
        if (filterFunc(record))
          excludedWriter.foreach(_.addAlignment(record))
        else
          keptWriter.addAlignment(record)
      }
    } finally {
      reader.close()
      keptWriter.close()
      excludedWriter.foreach(_.close())
    }
  }

370
371
372
373
374
375
376
  /**
   * Recursive function to parse command line options
   *
   * Each recognized flag is consumed from the head of the argument list and
   * stored in the option map. Flags which take a value may only be given once.
   *
   * @param opts OptionMap instance which may contain parsed options
   * @param list remaining command line arguments
   * @return OptionMap instance
   */
  def parseOption(opts: OptionMap, list: List[String]): OptionMap =
    // format: OFF
    list match {
      case Nil =>
        opts
      case ("--inputBAM" | "-I") :: value :: rest if !opts.contains("inputBAM") =>
        parseOption(opts + ("inputBAM" -> checkInputBAM(new File(value))), rest)
      case ("--targetRegions" | "-l") :: value :: rest if !opts.contains("targetRegions") =>
        parseOption(opts + ("targetRegions" -> checkInputFile(new File(value))), rest)
      case ("--outputBAM" | "-o") :: value :: rest if !opts.contains("outputBAM") =>
        parseOption(opts + ("outputBAM" -> new File(value)), rest)
      case ("--minMapQ" | "-Q") :: value :: rest if !opts.contains("minMapQ") =>
        parseOption(opts + ("minMapQ" -> value.toInt), rest)
      // TODO: better way to parse multiple flag values?
      case ("--readGroup" | "-RG") :: value :: rest if !opts.contains("readGroup") =>
        parseOption(opts + ("readGroup" -> value.split(",").toSet), rest)
      case ("--noMakeIndex") :: rest =>
        parseOption(opts + ("noMakeIndex" -> true), rest)
      case ("--limitToRegion" | "-limit") :: rest =>
        parseOption(opts + ("limitToRegion" -> true), rest)
      // TODO: implementation
      case ("--minOverlapFraction" | "-f") :: value :: rest if !opts.contains("minOverlapFraction") =>
        parseOption(opts + ("minOverlapFraction" -> value.toDouble), rest)
      // anything else is an unknown flag, a repeated flag, or a flag missing its value
      case option :: _ =>
        throw new IllegalArgumentException("Unexpected or duplicate option flag: " + option)
    }
  // format: ON
bow's avatar
bow committed
404

bow's avatar
bow committed
405
  /**
   * Function to validate OptionMap instances
   *
   * @param opts OptionMap instance to validate
   * @throws IllegalArgumentException if any of the required options is missing
   */
  private def validateOption(opts: OptionMap): Unit = {
    // TODO: better way to check for required arguments ~ use scalaz.Validation?
    if (!opts.contains("inputBAM"))
      throw new IllegalArgumentException("Input BAM not supplied")
    if (!opts.contains("targetRegions"))
      throw new IllegalArgumentException("Target regions file not supplied")
    // the output BAM is read unconditionally by main, so it is required as well
    if (!opts.contains("outputBAM"))
      throw new IllegalArgumentException("Output BAM not supplied")
  }
bow's avatar
bow committed
413

bow's avatar
bow committed
414
  /**
   * Function that returns the given File if it exists
   *
   * @param inFile file to check
   * @return the given file
   * @throws IOException if the file does not exist
   */
  def checkInputFile(inFile: File): File = {
    if (!inFile.exists)
      throw new IOException("Input file " + inFile.getPath + " not found")
    inFile
  }

bow's avatar
bow committed
421
  /**
   * Function that returns the given BAM file if it exists and is indexed
   *
   * The index is looked up next to the BAM file, either with '.bai' appended
   * to the full BAM file name (sample.bam.bai) or with the '.bam' extension
   * replaced by '.bai' (sample.bai).
   *
   * @param inBAM BAM file to check
   * @return the given BAM file
   * @throws IOException if the index or the BAM file itself is not found
   */
  def checkInputBAM(inBAM: File): File = {
    val bamPath = inBAM.getPath
    val indexCandidates = Seq(
      bamPath + ".bai",
      bamPath.stripSuffix(".bam") + ".bai",
      // accepted by earlier versions, kept for backward compatibility
      bamPath + ".bam.bai")
    if (indexCandidates.exists(candidate => new File(candidate).exists))
      checkInputFile(inBAM)
    else
      throw new IOException("Index for input BAM file " + inBAM.getPath + " not found")
  }

430
  /**
   * Command line entry point
   *
   * Prints the usage and exits with status 1 when no arguments are given.
   * Otherwise parses and validates the arguments, builds the filter function
   * from the target regions, and writes the filtered BAM file.
   *
   * @param args command line arguments
   */
  def main(args: Array[String]): Unit = {

    if (args.length == 0) {
      println(usage)
      System.exit(1)
    }
    val options = parseOption(Map(), args.toList)
    validateOption(options)

    val inputBAM = options("inputBAM").asInstanceOf[File]
    val outputBAM = options("outputBAM").asInstanceOf[File]

    val iv = makeRawIntervalFromFile(options("targetRegions").asInstanceOf[File])
    // limiting bloomSize to 70M and bloomFp to 4e-7 due to Int size limitation set in algebird
    val filterFunc = makeFilterOutFunction(iv = iv,
      inBAM = inputBAM,
      filterOutMulti = !options.getOrElse("limitToRegion", false).asInstanceOf[Boolean],
      minMapQ = options.getOrElse("minMapQ", 0).asInstanceOf[Int],
      // parseOption stores the read group IDs under the "readGroup" key
      readGroupIDs = options.getOrElse("readGroup", Set.empty[String]).asInstanceOf[Set[String]],
      bloomSize = 70000000, bloomFp = 4e-7)

    writeFilteredBAM(filterFunc, inputBAM, outputBAM,
      writeIndex = !options.getOrElse("noMakeIndex", false).asInstanceOf[Boolean])
  }

455
  /** Usage text of the command line tool; flags listed here mirror the ones accepted by parseOption */
  val usage: String =
    s"""
      |Usage: java -jar BiopetFramework.jar tool $name [options] -I input -l regions -o output
      |
      |$name - Tool for reads removal from an indexed BAM file
      |
      |Required arguments:
      |  -I,--inputBAM              Input BAM file, must be indexed with
      |                             '.bam.bai' or '.bai' extension
      |  -l,--targetRegions         Input BED file
      |  -o,--outputBAM             Output BAM file
      |
      |Optional arguments:
      |  -RG,--readGroup            Read groups to remove; set multiple read
      |                             groups using commas (default: all)
      |  -Q,--minMapQ               Minimum MAPQ value of reads in target region
      |                             (default: 0)
      |  --noMakeIndex              Do not write BAM output file index
      |                             (default: index is written)
      |  -limit,--limitToRegion     Whether to remove only reads in the target
      |                             regions and keep the same reads if they
      |                             map to other regions (default: not set)
      |
      |This tool will remove BAM records that overlaps a set of given regions.
      |By default, if the removed reads are also mapped to other regions outside
      |the given ones, they will also be removed.
    """.stripMargin
482
}