Skip to content

Instantly share code, notes, and snippets.

@quattro
Last active August 29, 2015 14:03
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save quattro/07726a0f173823e243ea to your computer and use it in GitHub Desktop.
Save quattro/07726a0f173823e243ea to your computer and use it in GitHub Desktop.
Haplotype frequency estimation from NGS reads with EM
#! /usr/bin/env python2.7
import argparse
import math
import os
import sys
import traceback
from collections import defaultdict, Counter
import numpy
import pysam
from Bio.SeqIO import parse
from numpy import random as nrdm
from numpy import ma
from scipy import stats
from scipy.special import gamma
class Model(object):
    """Alignment data for EM: distinct reads, variants and per-read match vectors.

    Reads are collapsed by sequence into ``read_counts``. For every mapped
    alignment record a 0/1 vector is built over the reference positions the
    record covers (1 = read base agrees with the variant base, 0 = mismatch or
    deletion). It is stored as a masked array whose mask is set at SNP columns
    (positions where the population variants disagree).
    """

    # Read bases accepted as a match for each (upper-case) IUPAC reference code.
    _IUPAC_MATCHES = {
        "A": frozenset("A"),
        "C": frozenset("C"),
        "G": frozenset("G"),
        "T": frozenset("T"),
        "R": frozenset("AG"),   # purine
        "Y": frozenset("CT"),   # pyrimidine
        "K": frozenset("GT"),   # keto
        "M": frozenset("AC"),   # amino
        "S": frozenset("CG"),   # strong
        "W": frozenset("AT"),   # weak
        "B": frozenset("CGT"),
        "D": frozenset("AGT"),
        "H": frozenset("ACT"),
        "V": frozenset("ACG"),
    }
    # Reference symbols that count as a match for any read base.
    _WILDCARDS = frozenset("-N")

    def __init__(self, sam_file, reads, population, coverage):
        """Index reads and variants, then build a match vector per mapped record.

        :param sam_file: opened alignment file; iterated once, and its
            ``references`` define the variant index order.
        :param reads: iterable of reads exposing ``.id`` and ``.seq``.
        :param population: iterable of variant records exposing ``.id``;
            indexing a record yields a single base.
        :param coverage: per-position read depth, stored as ``self.coverage``.
        """
        dreads = {read.id: str(read.seq) for read in reads}
        self.read_counts = Counter(dreads.values())
        self.population = {rec.id: rec for rec in population}
        self.names = list(sam_file.references)
        self.var_to_idx = {variant: idx for idx, variant in enumerate(self.names)}
        self.alignments = dict()
        self.paired = dict()
        self.coverage = coverage
        # The SNP mask spans the shortest variant instead of a hard-coded length.
        if self.population:
            length = min(len(rec) for rec in self.population.values())
        else:
            length = 0
        self.snps = self._get_snps(length, self.population.values())
        for record in sam_file:
            if record.is_unmapped:
                continue
            vname = self.var_from_index(record.tid)
            variant = self.population[vname]
            seq = dreads[record.qname]
            # CIGAR/qstart offsets index the sequence as stored in the alignment
            # record (reference orientation), not the original FASTA read.
            aligned_seq = str(record.seq) if record.seq else seq
            vec = self.build_algn_vec(record, aligned_seq, variant)
            snp_mask = self.snps[record.pos:record.pos + len(vec)]
            self.alignments[record] = ma.array(vec, mask=snp_mask)
            # Key by the FASTA sequence: that is what iterating the model yields,
            # and what em() passes to get_record().
            self.paired[(seq, vname)] = record

    def __iter__(self):
        """Iterate over the distinct read sequences."""
        return iter(self.read_counts)

    def count(self, seq):
        """Return how many reads have sequence *seq*."""
        return self.read_counts[seq]

    def total(self):
        """Return the total number of reads as a float."""
        return float(sum(self.read_counts.values()))

    def var_index(self, vname):
        """Return the index of variant *vname* in the alignment reference list."""
        return self.var_to_idx[vname]

    def var_from_index(self, index):
        """Return the variant name at reference *index*."""
        return self.names[index]

    def get_algn_vec(self, record):
        """Return the masked match vector for *record*, or None if unknown."""
        return self.alignments.get(record, None)

    def get_record(self, seq, vname):
        """Return the alignment record pairing read *seq* with *vname*, or None."""
        return self.paired.get((seq, vname), None)

    def _base_matches(self, ref_base, read_base):
        """Return 1 if *read_base* is compatible with IUPAC *ref_base*, else 0."""
        code = ref_base.upper()
        if code in self._WILDCARDS:
            return 1
        return int(read_base.upper() in self._IUPAC_MATCHES[code])

    def build_algn_vec(self, read, seq, ref):
        """Return a 0/1 numpy vector over the reference span covered by *read*.

        Walks ``read.cigar`` starting at reference position ``read.pos`` and
        query offset ``read.qstart``. Matched positions score via
        ``_base_matches``; deleted or skipped reference positions score 0, so
        the vector length equals the number of reference positions consumed.
        """
        i = read.pos
        j = read.qstart
        vec = []
        for op, count in read.cigar:
            if op in (0, 7, 8):  # M, =, X: read and reference both advance
                for _ in range(count):
                    vec.append(self._base_matches(ref[i], seq[j]))
                    i += 1
                    j += 1
            elif op == 1:  # I: bases present only in the read
                j += count
            elif op in (2, 3):  # D, N: bases present only in the reference
                i += count
                vec.extend([0] * count)
            # S/H/P consume no reference; a leading soft clip is already
            # skipped because j starts at qstart.
        return numpy.array(vec)

    def _get_snps(self, length, refs):
        """Return a 0/1 array: 1 where the first *length* columns of *refs* differ."""
        refs = list(refs)  # iterated once per column below
        return numpy.array([int(len(set(ref[idx] for ref in refs)) > 1)
                            for idx in range(length)])
def em(model, max_iter=100, tol=0.001):
    """Fit variant mixing weights and per-position error rates by MAP-EM.

    :param model: a Model-like object. It must support iteration over distinct
        read sequences and provide ``count``, ``total``, ``population``,
        ``var_index``, ``get_record``, ``get_algn_vec`` and a ``coverage``
        array (one entry per reference position).
    :param max_iter: maximum number of EM iterations.
    :param tol: stop once the absolute change of the penalized log-likelihood
        between two iterations falls below this value.
    :returns: ``(mix_weights, error_rates, log_likelihood)``. ``mix_weights``
        is indexed by ``model.var_index``. ``error_rates`` has one entry per
        position of ``model.coverage``. ``log_likelihood`` is the data
        log-likelihood plus the Beta and Dirichlet log-priors.

    Starting values are drawn at random from the priors, so repeated calls
    start from different points (used for random restarts).
    """
    min_eps = numpy.finfo(float).tiny
    # Largest error rate for which log(1 - 4 * err) is still finite.
    max_err = 0.25 * (1.0 - numpy.finfo(float).eps)

    # Mixing weights: Dirichlet prior with alpha = 1 (uniform).
    n_var = len(model.population)
    dir_alphas = numpy.ones(n_var)
    mix_weights = nrdm.dirichlet(dir_alphas)

    # Per-position error rate: Beta(1, 100) prior (mean 1/101, about 0.0099).
    # The error-rate update below is elementwise with model.coverage, so the
    # vector length must match it.
    variant_len = len(model.coverage)
    b_alphas = numpy.ones(variant_len)
    b_betas = numpy.ones(variant_len) * 100.0
    beta = stats.beta(b_alphas, b_betas)
    error_rates = beta.rvs()
    total_reads = model.total()

    def _get_prob(sq, vnme, mdel, err):
        """P(read | variant) as independent per-position Bernoulli trials."""
        record = mdel.get_record(sq, vnme)
        if record is None:
            return min_eps
        vec = mdel.get_algn_vec(record)
        # NOTE(review): Model builds vec with mask=snps, so ma.dot skips SNP
        # positions. The earlier comment said "only get positions with SNPs";
        # confirm which is intended.
        # Clip so an error rate of 0 (or >= 0.25) cannot produce log(0) = -inf,
        # which would turn into NaN when multiplied by a 0 in the vector.
        err_vec = numpy.clip(err[record.pos:record.pos + len(vec)], min_eps, max_err)
        bad = numpy.log(err_vec)
        good = numpy.log(1.0 - 4.0 * err_vec)
        return max(math.exp(ma.dot(bad, 1.0 - vec) + ma.dot(good, vec)), min_eps)

    # P(read | variant) for every (distinct read, variant) pair.
    prob_r_v = defaultdict(float)

    def _refresh_probs(err):
        """Recompute prob_r_v in place for error-rate vector *err*."""
        for seq in model:
            for vname in model.population:
                prob_r_v[(seq, vname)] = _get_prob(seq, vname, model, err)

    _refresh_probs(error_rates)

    likelihood = nlikelihood = float("-inf")
    delta = float("inf")
    r_rv = defaultdict(float)
    for _ in range(max_iter):
        likelihood = nlikelihood
        if delta < tol:
            break

        # E-step: responsibility of each variant for each distinct read.
        for seq in model:
            denom = 0.0
            for vname in model.population:
                denom += mix_weights[model.var_index(vname)] * prob_r_v[(seq, vname)]
            denom = max(denom, min_eps)
            for vname in model.population:
                vid = model.var_index(vname)
                r_rv[(seq, vname)] = mix_weights[vid] * prob_r_v[(seq, vname)] / denom

        # M-step (MAP estimation): expected variant counts and expected matches
        # per reference position.
        vsum = numpy.zeros(n_var)
        eps_sum = numpy.zeros(variant_len)
        for seq in model:
            count = model.count(seq)
            for vname in model.population:
                vid = model.var_index(vname)
                weight = count * r_rv[(seq, vname)]
                vsum[vid] += weight
                record = model.get_record(seq, vname)
                if record is not None:
                    vec = model.get_algn_vec(record)
                    eps_sum[record.pos:record.pos + len(vec)] += weight * vec

        mix_weights = (vsum + dir_alphas - 1.0) / (total_reads + dir_alphas.sum() - n_var)
        error_rates = (model.coverage - eps_sum + b_alphas - 1.0) / \
            (4.0 * (model.coverage + b_alphas + b_betas - 2.0))
        # Keep rates in the range the likelihood accepts. A negative rate
        # (eps_sum > coverage) would make both the log-likelihood and the log
        # prior NaN or -inf.
        error_rates = numpy.clip(error_rates, 0.0, 0.25)

        _refresh_probs(error_rates)

        # MAP objective: data log-likelihood plus log-priors. logpdf avoids the
        # underflow of log(pdf(...)).
        nlikelihood = log_likelihood(model, mix_weights, prob_r_v) + \
            beta.logpdf(error_rates).sum() + log_dirpdf(mix_weights, dir_alphas)
        delta = math.fabs(nlikelihood - likelihood)
        # MAP-EM should never decrease the objective; a decrease indicates a bug.
        assert nlikelihood >= likelihood
    return mix_weights, error_rates, nlikelihood
def log_dirpdf(x, alpha):
    """Return the log-density of a Dirichlet(*alpha*) distribution at *x*.

    log p(x | alpha) = log Gamma(sum alpha) - sum log Gamma(alpha_i)
                       + sum (alpha_i - 1) * log x_i

    The previous implementation multiplied by prod Gamma(alpha_i) and divided
    by prod x_i**(alpha_i - 1), i.e. it inverted both factors. The two agree
    only when every alpha_i == 1.

    Work is done in log space: ``gammaln`` avoids the overflow of
    ``math.gamma`` once sum(alpha) exceeds about 171. ``xlogy`` evaluates
    (alpha_i - 1) * log x_i as 0 when alpha_i == 1 and x_i == 0.

    :param x: point on the simplex (array-like of floats).
    :param alpha: concentration parameters, same length as *x*.
    :returns: the log-density as a float.
    """
    from scipy.special import gammaln, xlogy
    x = numpy.asarray(x, dtype=float)
    alpha = numpy.asarray(alpha, dtype=float)
    return float(gammaln(alpha.sum()) - gammaln(alpha).sum()
                 + xlogy(alpha - 1.0, x).sum())
def log_likelihood(model, mix_weights, prob_r_v):
    """Return the mixture log-likelihood of every distinct read in *model*.

    For each distinct read ``seq`` the mixture probability
    ``sum over variants v of mix_weights[var_index(v)] * prob_r_v[(seq, v)]``
    is computed. Its natural log is then weighted by ``model.count(seq)``.
    """
    total = 0.0
    for seq in model:
        mixture = sum(mix_weights[model.var_index(vname)] * prob_r_v[(seq, vname)]
                      for vname in model.population)
        total += model.count(seq) * math.log(mixture)
    return total
def main(args):
    """Command-line entry point: run EM with random restarts and write estimates.

    :param args: argument list without the program name, parsed with argparse.
    :returns: process exit status. 0 on success; 1 if EM raised, or if no
        restart produced a comparable (non-NaN) log-likelihood.
    """
    argp = argparse.ArgumentParser(description="EM Algorithm for QSPS.")
    argp.add_argument("bam_file", help="BAM file containing reads.")
    argp.add_argument("read_file", type=argparse.FileType("r"),
                      help="FASTA file containing sequencing reads.")
    argp.add_argument("pop_file", type=argparse.FileType("r"),
                      help="FASTA file containing viral population.")
    argp.add_argument("cov_file", type=argparse.FileType("r"),
                      help="Bedtools coverage file containing per-position read depth.")
    argp.add_argument("-r", "--restarts", type=int, default=0,
                      help="Total number of random restarts.")
    argp.add_argument("-m", "--max_iter", type=int, default=100,
                      help="Maximum number of iterations for EM.")
    argp.add_argument("-o", "--output", type=argparse.FileType("w"),
                      default=sys.stdout, help="Output file.")
    args = argp.parse_args(args)

    sam_file = pysam.Samfile(args.bam_file)
    # Per-position depth, 0-based. Rows are "<ignored>\t<1-based pos>\t<count>".
    # NOTE(review): length 1244 is hard-coded; presumably the reference length.
    coverage = numpy.zeros(1244)
    for row in args.cov_file:
        if not row.strip():
            continue  # tolerate blank lines such as a trailing newline
        _, pos, count = row.split("\t")
        coverage[int(pos) - 1] += float(int(count))
    model = Model(sam_file, parse(args.read_file, "fasta"),
                  parse(args.pop_file, "fasta"), coverage)

    # Run EM once plus args.restarts times; keep the best penalized likelihood.
    best_w, best_e, best_ll = None, None, float("-inf")
    try:
        for _ in range(args.restarts + 1):
            mix_weights, error_rates, ll = em(model, args.max_iter)
            if ll >= best_ll:
                best_w, best_e, best_ll = mix_weights, error_rates, ll
    except Exception:
        traceback.print_exc()
        return 1
    if best_w is None:
        # Only reachable when every restart returned NaN (NaN >= x is False).
        sys.stderr.write("EM produced no comparable (non-NaN) log-likelihood." + os.linesep)
        return 1

    # Output results: variants by decreasing weight, then per-position error rates.
    out = args.output
    out.write("Variant, Probability")
    out.write(os.linesep)
    ranked = sorted(((variant, best_w[model.var_index(variant)])
                     for variant in model.population),
                    key=lambda pair: pair[1], reverse=True)
    for variant, prob in ranked:
        out.write("{0},{1}{2}".format(variant, prob, os.linesep))
    out.write("Error-rate per position")
    out.write(os.linesep)
    for idx, err in enumerate(best_e):
        out.write("{0},{1}{2}".format(idx, err, os.linesep))
    return 0
if __name__ == "__main__":
    # Forward the command-line arguments and use main()'s return as exit status.
    exit_status = main(sys.argv[1:])
    sys.exit(exit_status)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment