Skip to content

Instantly share code, notes, and snippets.

@standage
Last active May 11, 2018 18:33
Show Gist options
  • Save standage/f017016a1fd98ceb8f93976c2965073e to your computer and use it in GitHub Desktop.
Save standage/f017016a1fd98ceb8f93976c2965073e to your computer and use it in GitHub Desktop.
Computing variant likelihoods
#!/usr/bin/env python
import argparse
from collections import defaultdict
import khmer
from math import log
import re
import scipy.stats
import sys
def load_smallcounttable(filename):
    """Load a khmer SmallCounttable from ``filename``.

    Progress is reported on stderr: a "Loading ..." prefix before the load
    and "done!" once it finishes. Used as an argparse ``type=`` converter.
    """
    progress = 'Loading counttable "{:s}"...'.format(filename)
    print(progress, end='', file=sys.stderr, flush=True)
    table = khmer.SmallCounttable.load(filename)
    print('done!', file=sys.stderr, flush=True)
    return table
def load_counttable(filename):
    """Load a khmer Counttable from ``filename``.

    Progress is reported on stderr: a "Loading ..." prefix before the load
    and "done!" once it finishes. Used as an argparse ``type=`` converter.
    """
    progress = 'Loading counttable "{:s}"...'.format(filename)
    print(progress, end='', file=sys.stderr, flush=True)
    table = khmer.Counttable.load(filename)
    print('done!', file=sys.stderr, flush=True)
    return table
def load_nodetable(filename):
    """Load a khmer Nodetable from ``filename``.

    Progress is reported on stderr: a "Loading ..." prefix before the load
    and "done!" once it finishes.
    """
    # The message previously said "counttable" (copy-pasted from the
    # counttable loaders) even though a Nodetable is loaded here.
    print('Loading nodetable "{:s}"...'.format(filename), end='',
          file=sys.stderr, flush=True)
    table = khmer.Nodetable.load(filename)
    print('done!', file=sys.stderr, flush=True)
    return table
def filter_refr(akmers, refr):
    """Keep only the k-mers from ``akmers`` whose abundance in ``refr`` is 0.

    Parameters:
        akmers: iterable of k-mers; it is traversed once.
        refr: a table whose ``get(kmer)`` returns an abundance.

    Returns:
        A ``(newa, newr)`` tuple. ``newa`` lists the retained k-mers in input
        order. ``newr`` is always an empty list.
    """
    # Previously the retained k-mers were collected twice (comprehension
    # plus an append loop), duplicating every entry for list inputs.
    newa = [a for a in akmers if refr.get(a) == 0]
    # NOTE(review): newr was never populated in the original either;
    # presumably it was meant to hold paired reference k-mers -- confirm.
    newr = list()
    return newa, newr
def get_abundances(kmers, casecounts, controlcounts):
    """Look up every k-mer's abundance in the case table and each control.

    Returns a list of ``len(controlcounts) + 1`` lists. Index 0 holds the
    case abundances and index ``i + 1`` holds those from
    ``controlcounts[i]``. Each inner list follows the order of ``kmers``.
    """
    tables = [casecounts]
    tables.extend(controlcounts)
    columns = [[] for _ in tables]
    for kmer in kmers:
        # Query the case first, then each control, for every k-mer.
        for column, table in zip(columns, tables):
            column.append(table.get(kmer))
    return columns
def set_error_rates(error, nsamples):
    """Expand an error-rate specification into one rate per sample.

    Parameters:
        error: a single float, applied to every sample, or a list of floats
            with exactly ``nsamples`` entries.
        nsamples: number of samples.

    Returns:
        A list of ``nsamples`` floats. For list input the same list object
        is returned.

    Raises:
        ValueError: if ``error`` is neither a float nor a list, if the list
            length differs from ``nsamples``, or if it holds a non-float.
    """
    if isinstance(error, float):
        return [error] * nsamples
    if isinstance(error, list):
        # Explicit checks instead of ``assert`` so validation still happens
        # under ``python -O`` and raises the same type as the branch below.
        if len(error) != nsamples:
            message = 'expected {:d} error rates, got {:d}'.format(
                nsamples, len(error))
            raise ValueError(message)
        if not all(isinstance(e, float) for e in error):
            message = 'error rate list {} contains a non-float'.format(error)
            raise ValueError(message)
        return error
    message = 'variable {} doesn\'t quack like a float'.format(error)
    message += ' or a list of floats'
    raise ValueError(message)
def abund_log_prob(genotype, abundance, refrabund=None, mean=30.0, sd=8.0,
                   error=0.01):
    """Return the log-probability score of a k-mer abundance for a genotype.

    Parameters:
        genotype: 0, 1 or 2.
            - 0: score is ``abundance * log(erate)``, where
              ``erate = error * mean`` multiplied by ``refrabund`` when it is
              truthy and by 0.1 otherwise.
            - 1: normal log-density with mean ``mean / 2`` and sd ``sd / 2``.
            - 2: normal log-density with mean ``mean`` and sd ``sd``.
        abundance: observed k-mer abundance.
        refrabund: optional abundance of the paired reference k-mer (only
            used for genotype 0).
        mean, sd: parameters of the genotype-2 abundance distribution.
        error: error rate (only used for genotype 0).

    Raises:
        ValueError: for any genotype other than 0, 1 or 2. This replaces the
            previous implicit ``None`` return, which failed later in callers.
    """
    if genotype == 0:
        erate = error * mean
        # A falsy refrabund (None or 0) falls back to a fixed 0.1 scale;
        # multiplying by 0 would make log(erate) undefined.
        if refrabund:
            erate *= refrabund
        else:
            erate *= 0.1
        return abundance * log(erate)
    if genotype == 1:
        return scipy.stats.norm.logpdf(abundance, mean / 2, sd / 2)
    if genotype == 2:
        return scipy.stats.norm.logpdf(abundance, mean, sd)
    raise ValueError(
        'unsupported genotype {!r}; expected 0, 1 or 2'.format(genotype))
def likelihood_denovo(abunds, refrabunds=None, mean=30.0, sd=8.0, error=0.01):
    """Log-likelihood score of the de novo model.

    Case k-mers are scored as genotype 1; every control k-mer is scored as
    genotype 0.

    Parameters:
        abunds: per-sample abundance lists; index 0 is the case and the
            remaining entries are controls.
        refrabunds: optional per-k-mer reference abundances, passed to the
            genotype-0 terms. None treats every k-mer as having none.
            NOTE: zip() stops at the shorter of a sample's list and
            refrabunds, so refrabunds should have one entry per k-mer.
        mean, sd: abundance-model parameters (see abund_log_prob).
        error: float or list of floats, one per sample (see set_error_rates).
    """
    errors = set_error_rates(error, nsamples=len(abunds))
    logsum = 0.0
    # Case
    for abund in abunds[0]:
        logsum += abund_log_prob(1, abund, mean=mean, sd=sd)
    # Controls: ``mean`` is now forwarded (it previously fell back to the
    # 30.0 default), so these terms equal the matching control terms
    # computed by likelihood_false for the same inputs.
    for alt, err in zip(abunds[1:], errors[1:]):
        refs = refrabunds if refrabunds is not None else [None] * len(alt)
        for a, r in zip(alt, refs):
            logsum += abund_log_prob(0, a, refrabund=r, mean=mean, error=err)
    return logsum
def likelihood_false(abunds, refrabunds=None, mean=30.0, error=0.01):
    """Log-likelihood score of the false-call model.

    Every k-mer in every sample (case and controls) is scored as genotype 0.

    Parameters:
        abunds: per-sample abundance lists.
        refrabunds: optional per-k-mer reference abundances, passed to each
            genotype-0 term. None (the default, which previously crashed in
            zip) treats every k-mer as having none.
            NOTE: zip() stops at the shorter of a sample's list and
            refrabunds, so refrabunds should have one entry per k-mer.
        mean: abundance-model mean (see abund_log_prob).
        error: float or list of floats, one per sample (see set_error_rates).
    """
    errors = set_error_rates(error, nsamples=len(abunds))
    logsum = 0.0
    for abundlist, e in zip(abunds, errors):
        refs = refrabunds if refrabunds is not None else [None] * len(abundlist)
        for abund, refr in zip(abundlist, refs):
            logsum += abund_log_prob(0, abund, refrabund=refr, mean=mean, error=e)
    return logsum
def likelihood_inherited(abunds, mean=30.0, sd=8.0, error=0.01):
    """Log-likelihood score of the inherited model for a case and two controls.

    For each k-mer, every genotype trio below is scored: the sum of the three
    per-sample abund_log_prob terms plus log(1/15). The best-scoring trio is
    kept, the per-k-mer maxima are summed, and log(15/11) is added once.

    Parameters:
        abunds: per-sample abundance lists; only abunds[0] (case), abunds[1]
            and abunds[2] are read, and k-mers are zipped across them.
        mean, sd: abundance-model parameters (see abund_log_prob).
        error: float or list of three floats (see set_error_rates).

    NOTE(review): log(1/15) is added per k-mer but log(15/11) only once per
    call, so the net prior is not 1/11 per k-mer -- confirm the intended
    normalization.
    """
    # Trios of (case, control 1, control 2) genotypes; the case genotype is
    # never 0. Variable names suggest child/mother/father -- confirm.
    trios = (
        (1, 0, 1), (1, 0, 2),
        (1, 1, 0), (1, 1, 1), (1, 1, 2),
        (1, 2, 0), (1, 2, 1),
        (2, 1, 1), (2, 1, 2),
        (2, 2, 1), (2, 2, 2),
    )
    err_c, err_m, err_f = set_error_rates(error, nsamples=3)
    log_prior = log(1.0 / 15.0)
    total = 0.0
    for a_c, a_m, a_f in zip(abunds[0], abunds[1], abunds[2]):
        best = max(
            abund_log_prob(g_c, a_c, mean=mean, sd=sd, error=err_c)
            + abund_log_prob(g_m, a_m, mean=mean, sd=sd, error=err_m)
            + abund_log_prob(g_f, a_f, mean=mean, sd=sd, error=err_f)
            + log_prior
            for g_c, g_m, g_f in trios
        )
        total += best
    return log(15.0 / 11.0) + total
def joinlist(thelist):
    """Render values as a comma-separated string, or '.' when there are none."""
    if len(thelist):
        return ','.join(map(str, thelist))
    return '.'
# Command-line interface. --case and --controls are always dereferenced while
# scoring, so they are declared required; the inheritance model reads the
# first two control tables, so at least two must be given.
parser = argparse.ArgumentParser(
    description='Annotate VCF records with k-mer abundance likelihood scores')
parser.add_argument('--case', type=load_counttable, required=True,
                    help='k-mer counttable for the case sample')
parser.add_argument('--controls', type=load_counttable, nargs='+',
                    required=True,
                    help='k-mer counttables for the control samples '
                    '(at least two; the inheritance model uses the first two)')
parser.add_argument('--mu', type=float, default=40.0,
                    help='mean of the genotype-2 abundance model '
                    '(genotype 1 uses half); default: %(default)s')
parser.add_argument('--sigma', type=float, default=8.0,
                    help='standard deviation of the genotype-2 abundance '
                    'model (genotype 1 uses half); default: %(default)s')
parser.add_argument('--epsilon', type=float, default=0.01,
                    help='error rate of the genotype-0 abundance model; '
                    'default: %(default)s')
parser.add_argument('--out', type=argparse.FileType('w'), default=sys.stdout,
                    help='output VCF; default: stdout')
parser.add_argument('--refr', type=load_smallcounttable,
                    help='small counttable of reference genome k-mers')
parser.add_argument('--case-min', type=int, default=5,
                    help='skip calls with no case k-mer abundance above this '
                    'value; default: %(default)s')
parser.add_argument('vcf', type=argparse.FileType('r'),
                    help='input VCF with ALTWINDOW annotations')
args = parser.parse_args()
if len(args.controls) < 2:
    parser.error('--controls requires at least two counttables')
# Score every VCF record and buffer the annotated calls by their PART= value;
# header lines are written out immediately.
calls_by_partition = defaultdict(list)
for record in args.vcf:
    if record.startswith('#'):
        # Header lines are copied to the output unchanged and never scored.
        print(record, end='', file=args.out)
        continue
    fields = record.strip().split('\t')

    # The ALTWINDOW sequence supplies the k-mers spanning the variant allele.
    windowmatch = re.search('ALTWINDOW=([^;\n]+)', record)
    if not windowmatch:
        continue
    window = windowmatch.group(1)
    if len(window) < args.case.ksize():
        message = 'WARNING: ignoring window "{:s}"'.format(window)
        message += ', shorter than ksize {:d}'.format(args.case.ksize())
        print(message, file=sys.stderr)
        continue
    windowmatch = re.search('REFRWINDOW=([^;\n]+)', record)
    refrwindow = windowmatch.group(1) if windowmatch else None

    # Pair each alt k-mer with the reference k-mer at the same offset when a
    # same-length reference window and a reference table are available.
    kmers = args.case.get_kmers(window)
    refrkmers = [None] * len(kmers)
    if refrwindow is not None and len(refrwindow) == len(window) and args.refr:
        refrkmers = args.refr.get_kmers(refrwindow)

    # Drop alt k-mers that also occur in the reference. ndropped is set
    # unconditionally (previously an unused `dropped` was initialized), so
    # DROPPED is 0 rather than undefined or stale when --refr is absent.
    ndropped = 0
    if args.refr:
        nkmers = len(kmers)
        kmerpairs = [(k, r) for k, r in zip(kmers, refrkmers)
                     if args.refr.get(k) == 0]
        kmers = [k for k, r in kmerpairs]
        refrkmers = [r for k, r in kmerpairs]
        ndropped = nkmers - len(kmers)
    if len(kmers) == 0:
        message = 'skipping variant with window "{:s}"'.format(window)
        message += '; all spanning k-mers are present in reference genome'
        print(message, file=sys.stderr)
        continue
    if refrkmers[0] is None:
        refrkmers = None

    abunds = get_abundances(kmers, args.case, args.controls)
    abovethresh = [a for a in abunds[0] if a > args.case_min]
    if len(abovethresh) == 0:
        # passenger call: no case k-mer abundance exceeds --case-min
        continue

    # One reference abundance per k-mer (previously sized by the number of
    # samples), so zip() in the likelihood functions covers every k-mer.
    refrabunds = [None] * len(kmers)
    if args.refr and refrkmers:
        refrabunds = [args.refr.get(k) for k in refrkmers]

    likedn = likelihood_denovo(abunds, refrabunds=refrabunds, mean=args.mu,
                               sd=args.sigma, error=args.epsilon)
    likefp = likelihood_false(abunds, refrabunds=refrabunds, mean=args.mu,
                              error=args.epsilon)
    likeih = likelihood_inherited(abunds, mean=args.mu, sd=args.sigma,
                                  error=args.epsilon)
    likescore = likedn - max(likefp, likeih)

    fields[7] += ';LIKESCORE={:.2f};LLDN={:.2f}'.format(likescore, likedn)
    fields[7] += ';LLFP={:.2f};LLIH={:.2f}'.format(likefp, likeih)
    fields[7] += ';DROPPED={:d}'.format(ndropped)
    fields[7] += ';CASE={:s}'.format(joinlist(abunds[0]))
    # One CTRLn annotation per control (identical output for two controls).
    for i, ctrlabunds in enumerate(abunds[1:], start=1):
        fields[7] += ';CTRL{:d}={:s}'.format(i, joinlist(ctrlabunds))
    # NOTE(review): a record without a PART= annotation raises
    # AttributeError here.
    part = re.search('PART=([^;\n]+)', fields[7]).group(1)
    calls_by_partition[part].append(fields)
def _likescore(call):
    """Return the LIKESCORE value recorded in a call's INFO column."""
    return float(re.search('LIKESCORE=([^;\n]+)', call[7]).group(1))


# Within each partition, every call scoring below the partition maximum gets
# FILTER 'PartitionScore'; all calls are then written in descending score
# order.
allcalls = list()
for calls in calls_by_partition.values():
    scores = [_likescore(call) for call in calls]
    best = max(scores)
    for call, score in zip(calls, scores):
        if score < best:
            call[6] = 'PartitionScore'
        allcalls.append(call)
allcalls.sort(key=_likescore, reverse=True)
for call in allcalls:
    print(*call, sep='\t', file=args.out)
#!/usr/bin/env python
import re
import sys
# Read lines from stdin and write them to stdout in descending LIKESCORE
# order. Lines starting with '#' are written first, in their original order,
# so header lines stay at the top. Previously they scored -inf and were
# sorted to the bottom. Other lines without a LIKESCORE sort last.
headers = list()
records = list()
for line in sys.stdin:
    record = line.strip()
    if record.startswith('#'):
        headers.append(record)
        continue
    match = re.search('LIKESCORE=([^\n;]+)', record)
    like = float(match.group(1)) if match else float('-inf')
    records.append((like, record))
# Stable sort: lines with equal scores keep their input order.
records.sort(key=lambda r: r[0], reverse=True)
for header in headers:
    print(header)
for like, record in records:
    print(record)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment