Last active
August 29, 2015 14:03
-
-
Save quattro/07726a0f173823e243ea to your computer and use it in GitHub Desktop.
haplotype w/ ngs reads em
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#! /usr/bin/env python2.7 | |
import argparse | |
import math | |
import os | |
import sys | |
import traceback | |
from collections import defaultdict, Counter | |
import numpy | |
import pysam | |
from Bio.SeqIO import parse | |
from numpy import random as nrdm | |
from numpy import ma | |
from scipy import stats | |
from scipy.special import gamma | |
class Model(object):
    """Pairs NGS reads with their alignments against a candidate haplotype population.

    Wraps an opened BAM file, the raw FASTA reads, the candidate variant
    (haplotype) sequences, and per-position coverage.  For every mapped
    read/variant pair it precomputes a 0/1 match vector, masked at SNP
    positions, which the EM algorithm consumes via ``get_algn_vec``.
    """

    def __init__(self, sam_file, reads, population, coverage):
        """Build lookup tables and per-alignment match vectors.

        :param sam_file: pysam.Samfile of reads aligned to the variants.
        :param reads: iterable of Bio.SeqRecord sequencing reads.
        :param population: iterable of Bio.SeqRecord candidate variants.
        :param coverage: numpy array of per-position read coverage.
        """
        mapped_reads = [rec for rec in sam_file if not rec.is_unmapped]
        # read name -> read sequence (as a plain string)
        dreads = dict([(read.id, str(read.seq)) for read in reads])
        # multiplicity of each distinct read sequence
        self.read_counts = Counter(dreads.values())
        self.population = dict([(rec.id, rec) for rec in population])
        self.names = list(sam_file.references)
        self.var_to_idx = dict([(variant, idx) for idx, variant in enumerate(self.names)])
        self.alignments = dict()
        self.paired = dict()
        self.coverage = coverage
        # 0/1 mask of polymorphic positions; 1244 is the hard-coded
        # reference length used throughout this script.
        self.snps = self._get_snps(1244, self.population.values())
        for record in mapped_reads:
            vname = self.var_from_index(record.tid)
            variant = self.population[vname]
            seq = dreads[record.qname]
            vec = self.build_algn_vec(record, seq, variant)
            # mask SNP positions so they do not contribute to the
            # error-model dot products in the EM step
            snp_mask = self.snps[record.pos:record.pos + len(vec)]
            self.alignments[record] = ma.array(vec, mask=snp_mask)
            # NOTE(review): the key uses record.seq (the sequence stored in
            # the BAM), while EM looks records up by the FASTA read sequence;
            # these agree only when the aligner stored the read unmodified
            # (forward strand, no hard clipping) -- confirm against the data.
            self.paired[(record.seq, vname)] = record

    def __iter__(self):
        """Iterate over the distinct read sequences."""
        return iter(self.read_counts)

    def count(self, seq):
        """Number of times the read sequence ``seq`` was observed."""
        return self.read_counts[seq]

    def total(self):
        """Total number of observed reads, as a float."""
        return float(sum(self.read_counts.values()))

    def var_index(self, vname):
        """Index of the variant named ``vname`` (BAM reference order)."""
        return self.var_to_idx[vname]

    def var_from_index(self, index):
        """Variant name for BAM reference index ``index``."""
        return self.names[index]

    def get_algn_vec(self, record):
        """Masked 0/1 alignment vector for ``record``, or None if unknown."""
        return self.alignments.get(record, None)

    def get_record(self, seq, vname):
        """Alignment record for (read sequence, variant name), or None."""
        return self.paired.get((seq, vname), None)

    def build_algn_vec(self, read, seq, ref):
        """Walk the CIGAR and score each reference position against the read.

        Emits one 0/1 entry per reference base consumed by the alignment:
        1 if the read base is compatible with the (possibly ambiguous)
        reference base, 0 otherwise; deleted reference bases score 0.

        :param read: aligned record exposing .pos, .qstart and .cigar.
        :param seq: the read sequence (string).
        :param ref: the reference/variant sequence (indexable by position).
        :return: numpy int array, one entry per consumed reference base.
        """
        keto = set("GT")
        purine = set("AG")
        pyrimidine = set("CT")
        # compatibility predicate per reference symbol; IUPAC ambiguity
        # codes R/Y/K accept either of their bases, gaps and Ns accept all
        func = {"A": lambda x: int(x == "A"),
                "C": lambda x: int(x == "C"),
                "G": lambda x: int(x == "G"),
                "T": lambda x: int(x == "T"),
                "R": lambda x: int(x in purine),
                "Y": lambda x: int(x in pyrimidine),
                "K": lambda x: int(x in keto),
                "-": lambda x: 1,
                "N": lambda x: 1}
        i = read.pos     # position in the reference
        j = read.qstart  # position in the read (already past leading soft clips)
        vec = list()
        for op, count in read.cigar:
            if op == 0:  # M: alignment match -- consumes both ref and read
                for k in range(count):
                    vec.append(func[ref[i]](seq[j]))
                    i += 1
                    j += 1
            elif op == 1:  # I: insertion to the reference -- consumes read only
                # BUGFIX: insertions were previously ignored entirely, so j
                # fell behind and every later match compared the wrong base.
                j += count
            elif op == 2:  # D: deletion from the reference -- consumes ref only
                i += count
                vec.extend([0] * count)
            # Other ops (soft/hard clips) contribute no aligned bases here;
            # leading soft clips are already accounted for by read.qstart.
        return numpy.array(vec)

    def _get_snps(self, length, refs):
        """Return a 0/1 array marking positions where ``refs`` disagree.

        :param length: reference length to scan.
        :param refs: sequences indexable by position.
        :return: numpy int array; 1 where more than one base occurs.
        """
        counts = defaultdict(lambda: defaultdict(int))
        for idx in range(length):
            for ref in refs:
                counts[idx][ref[idx]] += 1
        snps = numpy.array([int(len(counts[idx]) > 1) for idx in range(length)])
        return snps
def em(model, max_iter=100):
    """Run MAP expectation-maximisation over the read/variant model.

    Estimates variant mixing weights (under a uniform Dirichlet prior) and
    per-position sequencing error-rates (under a Beta prior) that best
    explain the observed reads.

    :param model: Model instance with precomputed alignment vectors.
    :param max_iter: maximum number of EM iterations.
    :return: tuple (mix_weights, error_rates, final log-posterior).
    """
    # smallest positive float; used as a floor to avoid log(0) / divide-by-0
    min_eps = numpy.finfo(float).tiny
    # init our probability distributions
    # probability of a read_pair given an alignment to an isoform
    # independent bernoulli trials
    prob_r_v = defaultdict(float)
    # probability of an isoform (mixing weights)
    # use Dirichlet prior with alpha = 1.0 (uniform)
    n_var = len(model.population)
    dir_alphas = numpy.array([1.0] * n_var)
    # random initial weights drawn from the prior (supports random restarts)
    mix_weights = nrdm.dirichlet(dir_alphas)
    # error-rate for each position
    # use beta prior with mean error-rate = 0.01 and var = 0.0001
    # this means a = 0.98, b = 97.02
    # NOTE(review): the parameters actually used below are 1.0 / 100.0,
    # not 0.98 / 97.02 as the comment above claims -- confirm intent.
    variant_len = 1244  # len(variants[0]) -- hard-coded reference length
    b_alphas = numpy.array([1.0] * variant_len)
    b_betas = numpy.array([100.0] * variant_len)
    beta = stats.beta(b_alphas, b_betas)
    error_rates = beta.rvs()
    total_reads = model.total()

    def _get_prob(sq, vnme, mdel, err):
        # P(read | variant): product of per-position Bernoulli terms,
        # evaluated in log-space; SNP positions are excluded because the
        # alignment vector is masked (ma.dot skips masked entries).
        record = mdel.get_record(sq, vnme)
        if record is not None:
            vec = mdel.get_algn_vec(record)
            # only get positions with SNPs.
            err_vec = err[record.pos:record.pos + len(vec)]
            bad = numpy.log(err_vec)
            # 1 - 4*err: probability of observing the aligned base under a
            # uniform error model over 4 alternatives -- presumably; confirm.
            good = numpy.log(1.0 - 4.0*err_vec)
            return max(math.exp(ma.dot(bad, 1.0 - vec) + ma.dot(good, vec)), min_eps)
        else:
            # read never aligned to this variant: floor probability
            return min_eps

    # calculate initial probability of seeing a read given a variant
    for seq in model:
        for vname in model.population:
            prob_r_v[(seq, vname)] = _get_prob(seq, vname, model, error_rates)
    likelihood = nlikelihood = -sys.maxint  # Python 2 only (removed in 3.x)
    delta = 1.0
    em_eps = 0.001  # convergence threshold on change in log-posterior
    r_rv = defaultdict(float)  # responsibilities P(variant | read)
    for iter in range(max_iter):  # NOTE(review): shadows the builtin `iter`
        likelihood = nlikelihood
        if delta < em_eps:
            break
        # E-step: compute responsibilities for each (read, variant) pair
        for seq in model:
            denom = 0.0
            for vname in model.population:
                vid = model.var_index(vname)
                denom += mix_weights[vid] * prob_r_v[(seq, vname)]
            denom = max(denom, min_eps)
            for vname in model.population:
                vid = model.var_index(vname)
                r_rv[(seq, vname)] = mix_weights[vid] * prob_r_v[(seq, vname)] / denom
        # M-step (MAP estimation)
        vsum = numpy.zeros(len(model.population))  # expected reads per variant
        eps_sum = numpy.zeros(variant_len)  # expected matching bases per position
        for seq in model:
            count = model.count(seq)
            for vname in model.population:
                vid = model.var_index(vname)
                responsibility = r_rv[(seq, vname)]
                vsum[vid] += count * responsibility
                record = model.get_record(seq, vname)
                if record is not None:
                    vec = model.get_algn_vec(record)
                    eps_sum[record.pos:record.pos + len(vec)] += count * responsibility * vec
        # new variant probabilities (mode of the Dirichlet posterior)
        mix_weights = (vsum + dir_alphas - 1.0) / (total_reads + dir_alphas.sum() - n_var)
        # new error-rates (mode of the Beta posterior; the 4.0 spreads the
        # error mass over the 4 alternative bases -- presumably; confirm)
        error_rates = (model.coverage - eps_sum + b_alphas - 1.0) /\
            (4.0 * (model.coverage + b_alphas + b_betas - 2.0))
        # update the probability of observing a read given a variant
        for seq in model:
            for vname in model.population:
                prob_r_v[(seq, vname)] = _get_prob(seq, vname, model, error_rates)
        # MAP estimation so use log-priors
        nlikelihood = log_likelihood(model, mix_weights, prob_r_v) +\
            numpy.log(beta.pdf(error_rates)).sum() + log_dirpdf(mix_weights, dir_alphas)
        delta = math.fabs(nlikelihood - likelihood)
        # This should always be the case; if it isn't there is a bug
        # NOTE(review): floating-point noise near convergence can trip this
        # assert even when EM behaves; confirm it is meant to be fatal.
        assert nlikelihood >= likelihood
    return mix_weights, error_rates, nlikelihood
def log_dirpdf(x, alpha):
    """Log-density of the Dirichlet distribution at the point ``x``.

    Computes ``lgamma(sum(alpha)) - sum(lgamma(alpha_i))
    + sum(log(x_i ** (alpha_i - 1)))``.

    BUGFIX: the previous version evaluated
    ``math.log(math.gamma(sum(alpha)) / p * a)`` which, by operator
    precedence, is ``log(Gamma(sum(alpha)) * a / p)`` -- the prior-product
    ``p`` and the normaliser ``a`` were inverted.  It also overflowed for
    moderately large alphas; using lgamma fixes both.  (For the uniform
    alpha = 1 prior used in ``em`` the two agree, since p = a = 1.)

    :param x: point on the simplex (array-like of probabilities).
    :param alpha: Dirichlet concentration parameters (array-like, > 0).
    :return: float log-density.
    """
    x = numpy.asarray(x, dtype=float)
    alpha = numpy.asarray(alpha, dtype=float)
    # log of the normalising constant 1 / B(alpha)
    log_norm = math.lgamma(float(alpha.sum())) - sum(math.lgamma(float(a)) for a in alpha)
    # x ** 0 == 1 even when x == 0, preserving the original's handling of
    # zero components under a uniform (alpha == 1) prior
    return log_norm + float(numpy.log(x ** (alpha - 1.0)).sum())
def log_likelihood(model, mix_weights, prob_r_v):
    """Data log-likelihood of the current EM parameters.

    Sums, over each distinct read sequence, its observation count times the
    log of its mixture probability under the variant weights.

    :param model: Model exposing iteration over reads, ``population``,
        ``var_index`` and ``count``.
    :param mix_weights: per-variant mixing weights, indexed by var_index.
    :param prob_r_v: dict mapping (read sequence, variant name) to
        P(read | variant).
    :return: float log-likelihood.
    """
    total = 0.0
    for read_seq in model:
        mixture = sum(
            mix_weights[model.var_index(name)] * prob_r_v[(read_seq, name)]
            for name in model.population
        )
        total += model.count(read_seq) * math.log(mixture)
    return total
def main(args):
    """Command-line entry point.

    Parses arguments, builds the Model from BAM / FASTA / coverage inputs,
    runs EM with optional random restarts, and writes the best variant
    probabilities and per-position error-rates to the output file.

    :param args: argument list (e.g. sys.argv[1:]).
    :return: process exit status -- 0 on success, 1 on EM failure.
    """
    argp = argparse.ArgumentParser(description="EM Algorithm for QSPS.")
    argp.add_argument("bam_file", help="BAM file containing reads.")
    argp.add_argument("read_file", type=argparse.FileType("r"),
                      help="FASTA file containing sequencing reads.")
    argp.add_argument("pop_file", type=argparse.FileType("r"),
                      help="FASTA file containing viral population.")
    argp.add_argument("cov_file", type=argparse.FileType("r"),
                      help="Bedtools coverage file containing.")
    argp.add_argument("-r", "--restarts", type=int, default=0,
                      help="Total number of random restarts.")
    argp.add_argument("-m", "--max_iter", type=int, default=100,
                      help="Maximum number of iterations for EM.")
    argp.add_argument("-o", "--output", type=argparse.FileType("w"),
                      default=sys.stdout, help="Output file.")
    args = argp.parse_args(args)
    sam_file = pysam.Samfile(args.bam_file)
    # per-position coverage; 1244 is the hard-coded reference length
    coverage = numpy.zeros(1244)
    for row in args.cov_file:
        # assumes exactly 3 tab-separated columns: name, 1-based position,
        # count -- TODO confirm against the actual bedtools output format
        _, pos, count = row.split("\t")
        coverage[int(pos) - 1] += float(int(count))
    model = Model(sam_file, parse(args.read_file, "fasta"), parse(args.pop_file, "fasta"), coverage)
    # run EM
    mix_weights, error_rates, ll = None, None, 0
    best_w, best_e, best_ll = None, None, -sys.maxint  # Python 2 only
    try:
        # restarts + 1 total runs; keep the run with the best log-posterior
        for _ in range(args.restarts + 1):
            mix_weights, error_rates, ll = em(model, args.max_iter)
            if ll >= best_ll:
                best_w = mix_weights
                best_e = error_rates
                best_ll = ll
    except Exception as e:
        # broad by design: report any EM failure and exit non-zero
        traceback.print_exc()
        return 1
    # Output results
    args.output.write("Variant, Probability")
    args.output.write(os.linesep)
    out_vals = []
    for variant in model.population:
        vid = model.var_index(variant)
        out_vals.append((variant, best_w[vid], os.linesep))
    # most probable variants first
    for out_val in sorted(out_vals, key=lambda x: x[1], reverse=True):
        out = "{0},{1}{2}".format(*out_val)
        args.output.write(out)
    args.output.write("Error-rate per position")
    args.output.write(os.linesep)
    for idx, err in enumerate(best_e):
        out = "{0},{1}{2}".format(idx, err, os.linesep)
        args.output.write(out)
    return 0
# Script entry point: strip argv[0] and hand the remaining arguments to main().
if __name__ == '__main__':
    sys.exit(main(sys.argv[1:]))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment