WipeReads.scala 16.4 KB
Newer Older
bow's avatar
bow committed
1
2
/**
 * Copyright (c) 2014 Leiden University Medical Center - Sequencing Analysis Support Core <sasc@lumc.nl>
 *
 * @author Wibowo Arindrarto <w.arindrarto@lumc.nl>
 */
package nl.lumc.sasc.biopet.tools
bow's avatar
bow committed
6

bow's avatar
bow committed
7
import java.io.File
bow's avatar
bow committed
8
import scala.collection.JavaConverters._
bow's avatar
bow committed
9
import scala.io.Source
bow's avatar
bow committed
10
import scala.math.{ max, min }
bow's avatar
bow committed
11

bow's avatar
bow committed
12
13
14
import com.google.common.hash.{ Funnel, BloomFilter, PrimitiveSink }
import htsjdk.samtools.SamReader
import htsjdk.samtools.SamReaderFactory
Peter van 't Hof's avatar
Peter van 't Hof committed
15
16
import htsjdk.samtools.QueryInterval
import htsjdk.samtools.ValidationStringency
17
import htsjdk.samtools.SAMFileWriter
18
import htsjdk.samtools.SAMFileWriterFactory
19
import htsjdk.samtools.SAMRecord
bow's avatar
bow committed
20
import htsjdk.samtools.util.{ Interval, IntervalTreeMap }
21
22
import htsjdk.tribble.AbstractFeatureReader.getFeatureReader
import htsjdk.tribble.bed.BEDCodec
23
import org.apache.commons.io.FilenameUtils.getExtension
bow's avatar
bow committed
24
25
26
import org.broadinstitute.gatk.utils.commandline.{ Input, Output }

import nl.lumc.sasc.biopet.core.BiopetJavaCommandLineFunction
27
import nl.lumc.sasc.biopet.core.ToolCommand
bow's avatar
bow committed
28
29
import nl.lumc.sasc.biopet.core.config.Configurable

bow's avatar
bow committed
30
// TODO: finish implementation for usage in pipelines
/**
 * WipeReads function class for usage in Biopet pipelines
 *
 * @param root Configuration object for the pipeline
 */
class WipeReads(val root: Configurable) extends BiopetJavaCommandLineFunction {

  // name of the class whose main method the java wrapper invokes
  // (resolves to the companion object's main at runtime)
  javaMainClass = getClass.getName

  @Input(doc = "Input BAM file (must be indexed)", shortName = "I", required = true)
  var inputBam: File = _

  @Output(doc = "Output BAM", shortName = "o", required = true)
  var outputBam: File = _

}

48
object WipeReads extends ToolCommand {
bow's avatar
bow committed
49

bow's avatar
bow committed
50
  /**
   * Creates a [[SamReader]] object from an input BAM file, ensuring it is indexed.
   *
   * @param inBam input BAM file; an accompanying index must exist
   * @return reader over the given BAM with lenient validation
   * @throws IllegalArgumentException if the BAM file has no index
   */
  private def prepInBam(inBam: File): SamReader = {
    val bam = SamReaderFactory
      .make()
      .validationStringency(ValidationStringency.LENIENT)
      .open(inBam)
    // fail early with a clear message instead of a bare "requirement failed"
    require(bam.hasIndex, "Input BAM file " + inBam.getPath + " must be indexed")
    bam
  }
bow's avatar
bow committed
59

bow's avatar
bow committed
60
  /**
   * Creates a [[SAMFileWriter]] for the output BAM, reusing the header of a template BAM.
   *
   * @param outBam output BAM file to write to
   * @param templateBam existing BAM whose header is copied onto the output
   * @param writeIndex whether to create an index for the output BAM (default: true)
   * @param async whether to use asynchronous I/O when writing (default: true)
   */
  private def prepOutBam(outBam: File, templateBam: File,
                         writeIndex: Boolean = true, async: Boolean = true): SAMFileWriter = {
    val factory = new SAMFileWriterFactory()
      .setCreateIndex(writeIndex)
      .setUseAsyncIo(async)
    val templateHeader = prepInBam(templateBam).getFileHeader
    // second argument marks the input as presorted, so records are written as given
    factory.makeBAMWriter(templateHeader, true, outBam)
  }

68
  /**
   * Creates a list of merged, sorted intervals given an input file.
   *
   * The file format is detected from the extension: BED, refFlat, or GTF.
   * Overlapping intervals on the same sequence are merged into a single interval.
   *
   * @param inFile input interval file
   * @param gtfFeatureType feature type to keep when parsing GTF files (default: "exon")
   * @throws IllegalArgumentException if the file extension is not bed/refflat/gtf
   */
  def makeIntervalFromFile(inFile: File, gtfFeatureType: String = "exon"): List[Interval] = {

    logger.info("Parsing interval file ...")

    /** Function to create iterator from BED file */
    def makeIntervalFromBed(inFile: File): Iterator[Interval] =
      asScalaIteratorConverter(getFeatureReader(inFile.toPath.toString, new BEDCodec(), false).iterator)
        .asScala
        .map(x => new Interval(x.getChr, x.getStart, x.getEnd))

    /**
     * Parses a refFlat file to yield Interval objects
     *
     * Format description:
     * http://genome.csdb.cn/cgi-bin/hgTables?hgsid=6&hgta_doSchemaDb=hg18&hgta_doSchemaTable=refFlat
     *
     * @param inFile input refFlat file
     */
    def makeIntervalFromRefFlat(inFile: File): Iterator[Interval] =
      Source.fromFile(inFile)
        // read each line
        .getLines()
        // skip all empty lines
        .filterNot(x => x.trim.isEmpty)
        // split per column
        .map(line => line.trim.split("\t"))
        // take chromosome and exonEnds and exonStarts
        .map(x => (x(2), x.reverse.take(2)))
        // split starts and ends based on comma
        .map(x => (x._1, x._2.map(y => y.split(","))))
        // zip exonStarts and exonEnds, note the index was reversed because we did .reverse above
        .map(x => (x._1, x._2(1).zip(x._2(0))))
        // make Intervals, accounting for the fact that refFlat coordinates are 0-based
        .map(x => x._2.map(y => new Interval(x._1, y._1.toInt + 1, y._2.toInt)))
        // flatten sublist
        .flatten

    /**
     * Parses a GTF file to yield Interval objects
     *
     * @param inFile input GTF file
     */
    def makeIntervalFromGtf(inFile: File): Iterator[Interval] =
      Source.fromFile(inFile)
        // read each line
        .getLines()
        // skip all empty lines
        .filterNot(x => x.trim.isEmpty)
        // skip UCSC track/browser lines and comment lines wherever they occur.
        // FIX: the previous `dropWhile(x => x.matches("^track | ^browser | ^#"))`
        // anchored the whole line (String.matches is full-match) and only looked
        // at leading lines, so header/comment lines were effectively never skipped
        .filterNot(x => x.startsWith("track") || x.startsWith("browser") || x.startsWith("#"))
        // split to columns
        .map(x => x.split("\t"))
        // exclude intervals whose type is different from the supplied one
        .filter(x => x(2) == gtfFeatureType)
        // and finally create the interval objects
        .map(x => new Interval(x(0), x(3).toInt, x(4).toInt))

    // detect interval file format from extension
    val iterFunc: (File => Iterator[Interval]) =
      getExtension(inFile.toString.toLowerCase) match {
        case "bed"     => makeIntervalFromBed
        case "refflat" => makeIntervalFromRefFlat
        case "gtf"     => makeIntervalFromGtf
        case _         => throw new IllegalArgumentException("Unexpected interval file type: " + inFile.getPath)
      }

    // sort intervals, then merge overlapping neighbours; because the input is
    // sorted, the accumulator head is always the most recently merged interval
    iterFunc(inFile).toList
      .sortBy(x => (x.getSequence, x.getStart, x.getEnd))
      .foldLeft(List.empty[Interval])(
        (acc, x) => {
          acc match {
            case head :: tail if x.intersects(head) =>
              new Interval(x.getSequence, min(x.getStart, head.getStart), max(x.getEnd, head.getEnd)) :: tail
            case _ => x :: acc
          }
        }
      )
  }

155
  // TODO: set minimum fraction for overlap
  /**
   * Function to create function to check SAMRecord for exclusion in filtered BAM file.
   *
   * The returned function evaluates all filtered-in SAMRecord to false.
   *
   * @param ivl iterator yielding Feature objects
   * @param inBam input BAM file
   * @param filterOutMulti whether to filter out reads with same name outside target region (default: true)
   * @param minMapQ minimum MapQ of reads in target region to filter out (default: 0)
   * @param readGroupIds read group IDs of reads in target region to filter out (default: all IDs)
   * @param bloomSize expected size of elements to contain in the Bloom filter
   * @param bloomFp expected Bloom filter false positive rate
   * @return function that checks whether a SAMRecord or String is to be excluded
   */
  def makeFilterNotFunction(ivl: List[Interval],
                            inBam: File,
                            filterOutMulti: Boolean = true,
                            minMapQ: Int = 0, readGroupIds: Set[String] = Set(),
                            bloomSize: Long, bloomFp: Double): (SAMRecord => Boolean) = {

    logger.info("Building set of reads to exclude ...")

    /**
     * Creates an Option[QueryInterval] object from the given Interval
     *
     * Tries the sequence name as-is first, then retries with the "chr" prefix
     * stripped or added, to tolerate UCSC-vs-Ensembl naming mismatches.
     *
     * @param in input BAM file
     * @param iv input interval
     * @return query interval, or None if the sequence is absent from the BAM header
     */
    def makeQueryInterval(in: SamReader, iv: Interval): Option[QueryInterval] = {
      val getIndex = in.getFileHeader.getSequenceIndex _
      if (getIndex(iv.getSequence) > -1)
        Some(new QueryInterval(getIndex(iv.getSequence), iv.getStart, iv.getEnd))
      else if (iv.getSequence.startsWith("chr") && getIndex(iv.getSequence.substring(3)) > -1) {
        logger.warn("Removing 'chr' prefix from interval " + iv.toString)
        Some(new QueryInterval(getIndex(iv.getSequence.substring(3)), iv.getStart, iv.getEnd))
      } else if (!iv.getSequence.startsWith("chr") && getIndex("chr" + iv.getSequence) > -1) {
        logger.warn("Adding 'chr' prefix to interval " + iv.toString)
        Some(new QueryInterval(getIndex("chr" + iv.getSequence), iv.getStart, iv.getEnd))
      } else {
        logger.warn("Sequence " + iv.getSequence + " does not exist in alignment")
        None
      }
    }

    /**
     * Function to ensure that a SAMRecord overlaps our target regions
     *
     * This is required because htsjdk's queryOverlap method does not take into
     * account the SAMRecord splicing structure
     *
     * @param rec SAMRecord to check
     * @param ivtm mutable mapping of a chromosome and its interval tree
     * @return true when at least one alignment block overlaps a target interval
     */
    def alignmentBlockOverlaps(rec: SAMRecord, ivtm: IntervalTreeMap[_]): Boolean =
      // if SAMRecord is not spliced, assume queryOverlap has done its job
      // otherwise check for alignment block overlaps in our interval list
      // using raw SAMString to bypass cigar string decoding
      // (column 5 of the raw SAM line is the CIGAR; "N" marks a skipped region)
      if (rec.getSAMString.split("\t")(5).contains("N"))
        rec.getAlignmentBlocks.asScala
          .exists(x =>
            ivtm.containsOverlapping(
              new Interval(rec.getReferenceName,
                x.getReferenceStart, x.getReferenceStart + x.getLength - 1)))
      else
        true

    /** function to create a fake SAMRecord pair ~ hack to limit querying BAM file for real pair */
    def makeMockPair(rec: SAMRecord): SAMRecord = {
      require(rec.getReadPairedFlag)
      val fakePair = rec.clone.asInstanceOf[SAMRecord]
      fakePair.setAlignmentStart(rec.getMateAlignmentStart)
      fakePair
    }

    /** function to create set element from SAMRecord */
    // with filterOutMulti the read name alone identifies the read (all of its
    // alignments are removed); otherwise the alignment start disambiguates
    def elemFromSam(rec: SAMRecord): String = {
      if (filterOutMulti)
        rec.getReadName
      else
        rec.getReadName + "_" + rec.getAlignmentStart.toString
    }

    /** object for use by BloomFilter */
    object SAMFunnel extends Funnel[SAMRecord] {
      override def funnel(rec: SAMRecord, into: PrimitiveSink): Unit = {
        val elem = elemFromSam(rec)
        logger.debug("Adding " + elem + " to set ...")
        into.putUnencodedChars(elem)
      }
    }

    /** filter function for read IDs */
    val rgFilter =
      if (readGroupIds.size == 0)
        (r: SAMRecord) => true
      else
        (r: SAMRecord) => readGroupIds.contains(r.getReadGroup.getReadGroupId)

    val readyBam = prepInBam(inBam)

    val queryIntervals = ivl
      .flatMap(x => makeQueryInterval(readyBam, x))
      // queryOverlapping only accepts a sorted QueryInterval collection ...
      .sortBy(x => (x.referenceIndex, x.start, x.end))
      // and it has to be an array
      .toArray

    // interval tree used for the splice-aware overlap check above
    val ivtm: IntervalTreeMap[_] = ivl
      .foldLeft(new IntervalTreeMap[Boolean])(
        (acc, x) => {
          acc.put(x, true)
          acc
        }
      )

    lazy val filteredOutSet: BloomFilter[SAMRecord] = readyBam
      // query BAM file with intervals
      .queryOverlapping(queryIntervals)
      // for compatibility
      .asScala
      // ensure spliced reads have at least one block overlapping target region
      .filter(x => alignmentBlockOverlaps(x, ivtm))
      // filter for MAPQ on target region reads
      .filter(x => x.getMappingQuality >= minMapQ)
      // filter on specific read group IDs
      .filter(x => rgFilter(x))
      // fold starting from empty set
      // NOTE(review): bloomSize is narrowed with .toInt here; values above
      // Int.MaxValue would overflow — confirm the expected upper bound
      .foldLeft(BloomFilter.create(SAMFunnel, bloomSize.toInt, bloomFp)
      )((acc, rec) => {
        acc.put(rec)
        // also register the mate's position so the pair is caught without a re-query
        if (rec.getReadPairedFlag) acc.put(makeMockPair(rec))
        acc
      })

    if (filterOutMulti)
      (rec: SAMRecord) => filteredOutSet.mightContain(rec)
    else
      (rec: SAMRecord) => {
        if (rec.getReadPairedFlag)
          filteredOutSet.mightContain(rec) && filteredOutSet.mightContain(makeMockPair(rec))
        else
          filteredOutSet.mightContain(rec)
      }
  }

303
304
305
306
  /**
   * Function to filter input BAM and write its output to the filesystem
   *
   * @param filterFunc filter function that evaluates true for excluded SAMRecord
   * @param inBam input BAM file
   * @param outBam output BAM file
   * @param filteredOutBam whether to write excluded SAMRecords to their own BAM file
   */
  def writeFilteredBam(filterFunc: (SAMRecord => Boolean), inBam: SamReader, outBam: SAMFileWriter,
                       filteredOutBam: Option[SAMFileWriter] = None) = {

    logger.info("Writing output file(s) ...")
    try {
      // running tallies of kept vs. discarded records
      var numIncluded = 0
      var numExcluded = 0
      inBam.asScala.foreach { rec =>
        if (filterFunc(rec)) {
          // excluded record: optionally persist it to the discarded-reads BAM
          filteredOutBam.foreach(_.addAlignment(rec))
          numExcluded += 1
        } else {
          outBam.addAlignment(rec)
          numIncluded += 1
        }
      }
      // emit a small tab-separated summary on stdout
      println(List("count_included", "count_excluded").mkString("\t"))
      println(List(numIncluded, numExcluded).mkString("\t"))
    } finally {
      // always release reader and writers, even when writing fails midway
      inBam.close()
      outBam.close()
      filteredOutBam.foreach(_.close())
    }
  }

bow's avatar
bow committed
335
  /** Default arguments */
  case class Args(inputBam: File = new File(""), // input BAM to filter (must be indexed)
                  targetRegions: File = new File(""), // interval file (BED/GTF/refFlat)
                  outputBam: File = new File(""), // filtered output BAM
                  filteredOutBam: Option[File] = None, // optional BAM receiving discarded reads
                  readGroupIds: Set[String] = Set.empty[String], // restrict removal to these read groups
                  minMapQ: Int = 0, // minimum MAPQ of target-region reads to remove
                  limitToRegion: Boolean = false, // when true, keep multi-mapped reads outside targets
                  noMakeIndex: Boolean = false, // when true, skip indexing the output BAM
                  featureType: String = "exon", // GTF feature type used as interval (GTF input only)
                  bloomSize: Long = 70000000, // expected Bloom filter element count
                  bloomFp: Double = 4e-7) extends AbstractArgs // Bloom filter false positive rate
347

bow's avatar
bow committed
348
  /** Command line argument parser */
  // NOTE: declaration order below defines the order of the generated help text
  class OptParser extends AbstractOptParser {

    head(
      s"""
        |$commandName - Region-based reads removal from an indexed BAM file
      """.stripMargin)

    opt[File]('I', "input_file") required () valueName "<bam>" action { (x, c) =>
      c.copy(inputBam = x)
    } validate {
      x => if (x.exists) success else failure("Input BAM file not found")
    } text "Input BAM file"

    opt[File]('r', "interval_file") required () valueName "<bed/gtf/refflat>" action { (x, c) =>
      c.copy(targetRegions = x)
    } validate {
      x => if (x.exists) success else failure("Target regions file not found")
    } text "Interval BED file"

    opt[File]('o', "output_file") required () valueName "<bam>" action { (x, c) =>
      c.copy(outputBam = x)
    } text "Output BAM file"

    opt[File]('f', "discarded_file") optional () valueName "<bam>" action { (x, c) =>
      c.copy(filteredOutBam = Some(x))
    } text "Discarded reads BAM file (default: none)"

    opt[Int]('Q', "min_mapq") optional () action { (x, c) =>
      c.copy(minMapQ = x)
    } text "Minimum MAPQ of reads in target region to remove (default: 0)"

    // may be given multiple times; each occurrence adds one read group ID
    opt[String]('G', "read_group") unbounded () optional () valueName "<rgid>" action { (x, c) =>
      c.copy(readGroupIds = c.readGroupIds + x)
    } text "Read group IDs to be removed (default: remove reads from all read groups)"

    opt[Unit]("limit_removal") optional () action { (_, c) =>
      c.copy(limitToRegion = true)
    } text
      "Whether to remove multiple-mapped reads outside the target regions (default: yes)"

    opt[Unit]("no_make_index") optional () action { (_, c) =>
      c.copy(noMakeIndex = true)
    } text
      "Whether to index output BAM file or not (default: yes)"

    note("\nGTF-only options:")

    opt[String]('t', "feature_type") optional () valueName "<gtf_feature_type>" action { (x, c) =>
      c.copy(featureType = x)
    } text "GTF feature containing intervals (default: exon)"

    note("\nAdvanced options:")

    opt[Long]("bloom_size") optional () action { (x, c) =>
      c.copy(bloomSize = x)
    } text "Expected maximum number of reads in target regions (default: 7e7)"

    opt[Double]("false_positive") optional () action { (x, c) =>
      c.copy(bloomFp = x)
    } text "False positive rate (default: 4e-7)"

    note(
      """
        |This tool will remove BAM records that overlaps a set of given regions.
        |By default, if the removed reads are also mapped to other regions outside
        |the given ones, they will also be removed.
      """.stripMargin)

  }

bow's avatar
bow committed
419
  /**
   * Parses the command line argument
   *
   * Exits the JVM with status 1 when parsing fails (scopt prints the error).
   */
  def parseArgs(args: Array[String]): Args =
    new OptParser().parse(args, Args()) match {
      case Some(parsed) => parsed
      case None         => sys.exit(1)
    }

424
  /** Tool entry point: parses arguments, builds the exclusion filter, writes the filtered BAM(s). */
  def main(args: Array[String]): Unit = {

    val cmdArgs: Args = parseArgs(args)

    val intervals = makeIntervalFromFile(cmdArgs.targetRegions, gtfFeatureType = cmdArgs.featureType)
    val makeIndex = !cmdArgs.noMakeIndex

    // cannot use SamReader as inBam directly since it only allows one active iterator at any given time
    val filterFunc = makeFilterNotFunction(
      ivl = intervals,
      inBam = cmdArgs.inputBam,
      filterOutMulti = !cmdArgs.limitToRegion,
      minMapQ = cmdArgs.minMapQ,
      readGroupIds = cmdArgs.readGroupIds,
      bloomSize = cmdArgs.bloomSize,
      bloomFp = cmdArgs.bloomFp
    )

    val outWriter = prepOutBam(cmdArgs.outputBam, cmdArgs.inputBam, writeIndex = makeIndex)
    val discardedWriter = cmdArgs.filteredOutBam
      .map(f => prepOutBam(f, cmdArgs.inputBam, writeIndex = makeIndex))

    writeFilteredBam(filterFunc, prepInBam(cmdArgs.inputBam), outWriter, discardedWriter)
  }
bow's avatar
bow committed
446
}