Last active
May 11, 2018 18:33
-
-
Save standage/f017016a1fd98ceb8f93976c2965073e to your computer and use it in GitHub Desktop.
Computing variant likelihoods
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
import argparse | |
from collections import defaultdict | |
import khmer | |
from math import log | |
import re | |
import scipy.stats | |
import sys | |
def load_smallcounttable(filename):
    """Load a khmer SmallCounttable from *filename*.

    Progress is reported on stderr (start message without newline, then
    'done!') so it does not mix with output written to stdout.
    """
    announce = 'Loading counttable "{:s}"...'.format(filename)
    print(announce, end='', file=sys.stderr, flush=True)
    table = khmer.SmallCounttable.load(filename)
    print('done!', file=sys.stderr, flush=True)
    return table
def load_counttable(filename):
    """Load a khmer Counttable from *filename*.

    Progress is reported on stderr (start message without newline, then
    'done!') so it does not mix with output written to stdout.
    """
    announce = 'Loading counttable "{:s}"...'.format(filename)
    print(announce, end='', file=sys.stderr, flush=True)
    table = khmer.Counttable.load(filename)
    print('done!', file=sys.stderr, flush=True)
    return table
def load_nodetable(filename):
    """Load a khmer Nodetable from *filename*.

    Progress is reported on stderr (start message without newline, then
    'done!') so it does not mix with output written to stdout.
    """
    announce = 'Loading nodetable "{:s}"...'.format(filename)
    print(announce, end='', file=sys.stderr, flush=True)
    table = khmer.Nodetable.load(filename)
    print('done!', file=sys.stderr, flush=True)
    return table
def filter_refr(akmers, refr):
    """Keep only the k-mers that are absent from a reference table.

    Args:
        akmers: iterable of k-mers to filter.
        refr: table with a ``get(kmer)`` method returning an abundance;
            a k-mer is kept when ``refr.get(kmer) == 0``.

    Returns:
        ``(newa, newr)``: ``newa`` is the list of kept k-mers in input
        order (each appears once). ``newr`` is always an empty list; it is
        returned only to keep the two-element interface.
        NOTE(review): ``newr`` was presumably meant to hold the paired
        reference k-mers -- confirm intent before relying on it.
    """
    # A single pass: the comprehension alone builds the result. (Appending
    # in an extra loop as well would list every kept k-mer twice.)
    newa = [a for a in akmers if refr.get(a) == 0]
    newr = list()
    return newa, newr
def get_abundances(kmers, casecounts, controlcounts):
    """Look up every k-mer in the case table and each control table.

    Returns a list of per-sample columns: element 0 holds the case
    abundances, element i+1 the abundances from ``controlcounts[i]``.
    Each column is in the same order as *kmers*; values are whatever the
    table's ``get`` returns.
    """
    tables = [casecounts] + list(controlcounts)
    columns = [[] for _ in tables]
    # k-mers are the outer loop, so *kmers* is traversed exactly once.
    for kmer in kmers:
        for column, table in zip(columns, tables):
            column.append(table.get(kmer))
    return columns
def set_error_rates(error, nsamples):
    """Expand an error-rate specification into one rate per sample.

    Args:
        error: a single float applied to every sample, or a list of floats
            with exactly one entry per sample.
        nsamples: number of samples.

    Returns:
        A list of *nsamples* floats. When *error* is already a list it is
        returned as-is (not copied).

    Raises:
        ValueError: if *error* is neither a float nor a list of floats, or
            if a list has the wrong length. (Explicit raises rather than
            ``assert`` so validation still happens under ``python -O``.)
    """
    if isinstance(error, float):
        return [error] * nsamples
    if isinstance(error, list):
        if len(error) != nsamples:
            message = 'expected {:d} error rates, got {:d}'.format(
                nsamples, len(error))
            raise ValueError(message)
        for e in error:
            if not isinstance(e, float):
                raise ValueError('error rate {!r} is not a float'.format(e))
        return error
    message = 'variable {} doesn\'t quack like a float'.format(error)
    message += ' or a list of floats'
    raise ValueError(message)
def abund_log_prob(genotype, abundance, refrabund=None, mean=30.0, sd=8.0,
                   error=0.01):
    """Return the log-probability score of a k-mer abundance under a genotype.

    Args:
        genotype: 0, 1 or 2 -- number of copies of the k-mer the sample is
            modelled as carrying (presumably alt-allele copies; confirm).
        abundance: observed k-mer abundance.
        refrabund: abundance of the paired reference k-mer, or None. Only
            used when *genotype* is 0.
        mean: expected abundance for genotype 2 (genotype 1 uses mean / 2).
        sd: standard deviation for genotype 2 (genotype 1 uses sd / 2).
        error: error rate; only used when *genotype* is 0.

    Returns:
        genotype 0: ``abundance * log(error * mean * scale)`` where *scale*
            is *refrabund* when it is truthy, otherwise 0.1.
        genotype 1: normal log-density of *abundance* at mean/2, sd/2.
        genotype 2: normal log-density of *abundance* at mean, sd.

    Raises:
        ValueError: for any other *genotype* (instead of silently returning
            None, which would only fail later in a caller's sum).
    """
    if genotype == 0:
        erate = error * mean
        # None or 0 reference abundance falls back to a fixed 0.1 scale;
        # scaling by 0 would make log() fail.
        if refrabund:
            erate *= refrabund
        else:
            erate *= 0.1
        return abundance * log(erate)
    if genotype == 1:
        return scipy.stats.norm.logpdf(abundance, mean / 2, sd / 2)
    if genotype == 2:
        return scipy.stats.norm.logpdf(abundance, mean, sd)
    raise ValueError('genotype must be 0, 1 or 2, got {!r}'.format(genotype))
def likelihood_denovo(abunds, refrabunds=None, mean=30.0, sd=8.0, error=0.01):
    """Log-likelihood of the de novo model.

    The case sample (``abunds[0]``) is scored as genotype 1 at every k-mer.
    Every control sample (``abunds[1:]``) is scored as genotype 0, i.e. any
    observed abundance is attributed to error.

    Args:
        abunds: per-sample lists of k-mer abundances, case first.
        refrabunds: one reference abundance per k-mer (used to scale the
            control error model), or None for no reference information.
            Must be as long as each sample list -- ``zip`` truncates
            otherwise.
        mean, sd: coverage parameters passed to ``abund_log_prob``.
        error: float or per-sample list of floats (see ``set_error_rates``).

    Returns:
        The summed log-probability.
    """
    errors = set_error_rates(error, nsamples=len(abunds))
    logsum = 0.0
    # Case
    for abund in abunds[0]:
        logsum += abund_log_prob(1, abund, mean=mean, sd=sd)
    # Controls. ``mean`` is passed through so the error model matches the one
    # used by likelihood_false; otherwise the default 30.0 would be used even
    # when the caller supplies a different coverage.
    for alt, err in zip(abunds[1:], errors[1:]):
        refs = refrabunds if refrabunds is not None else [None] * len(alt)
        for a, r in zip(alt, refs):
            logsum += abund_log_prob(0, a, refrabund=r, mean=mean, error=err)
    return logsum
def likelihood_false(abunds, refrabunds=None, mean=30.0, error=0.01):
    """Log-likelihood of the false-positive model.

    Every sample, case included, is scored as genotype 0 at every k-mer, so
    all observed abundance is attributed to error.

    Args:
        abunds: per-sample lists of k-mer abundances, case first.
        refrabunds: one reference abundance per k-mer, or None for no
            reference information (previously the None default crashed in
            ``zip``). Must be as long as each sample list -- ``zip``
            truncates otherwise.
        mean: coverage parameter passed to ``abund_log_prob``.
        error: float or per-sample list of floats (see ``set_error_rates``).

    Returns:
        The summed log-probability.
    """
    errors = set_error_rates(error, nsamples=len(abunds))
    logsum = 0.0
    for abundlist, e in zip(abunds, errors):
        refs = refrabunds if refrabunds is not None else [None] * len(abundlist)
        for abund, refr in zip(abundlist, refs):
            logsum += abund_log_prob(0, abund, refrabund=refr, mean=mean, error=e)
    return logsum
def likelihood_inherited(abunds, mean=30.0, sd=8.0, error=0.01):
    """Log-likelihood of the inherited-variant model.

    Expects exactly three samples: ``abunds[0]`` is the case and
    ``abunds[1]``/``abunds[2]`` are the two controls (the variable names
    suggest mother/father of a trio -- confirm). For each k-mer position the
    best-scoring genotype triple (case, control1, control2) is chosen from
    the 11 listed below, each weighted by log(1/15); the per-position maxima
    are summed and log(15/11) is added once.

    The 11 triples appear to be the Mendelian-consistent trios in which the
    case carries the allele, out of 15 Mendelian-consistent trios in total.
    NOTE(review): log(1/15) is added per k-mer but log(15/11) only once --
    confirm this normalisation is intended.
    """
    trios = (
        (1, 0, 1), (1, 0, 2),
        (1, 1, 0), (1, 1, 1), (1, 1, 2),
        (1, 2, 0), (1, 2, 1),
        (2, 1, 1), (2, 1, 2),
        (2, 2, 1), (2, 2, 2),
    )
    rates = set_error_rates(error, nsamples=3)
    weight = log(1.0 / 15.0)
    total = 0.0
    for case_ab, ctrl1_ab, ctrl2_ab in zip(abunds[0], abunds[1], abunds[2]):
        # max() keeps the first maximal candidate, matching a strict '>' scan.
        best = max(
            abund_log_prob(g_case, case_ab, mean=mean, sd=sd, error=rates[0])
            + abund_log_prob(g_c1, ctrl1_ab, mean=mean, sd=sd, error=rates[1])
            + abund_log_prob(g_c2, ctrl2_ab, mean=mean, sd=sd, error=rates[2])
            + weight
            for g_case, g_c1, g_c2 in trios
        )
        total += best
    return log(15.0 / 11.0) + total  # 1 / (11/15)
def joinlist(thelist):
    """Render *thelist* as a comma-separated string of ``str()`` values.

    An empty list is rendered as the placeholder '.'.
    """
    if len(thelist) > 0:
        return ','.join(map(str, thelist))
    return '.'
# Command-line driver: score every VCF call, flag all but the best call in
# each partition, and write the calls sorted by LIKESCORE.
parser = argparse.ArgumentParser()
parser.add_argument('--case', type=load_counttable)
parser.add_argument('--controls', type=load_counttable, nargs='+')
parser.add_argument('--mu', type=float, default=40.0)
parser.add_argument('--sigma', type=float, default=8.0)
parser.add_argument('--epsilon', type=float, default=0.01)
parser.add_argument('--out', type=argparse.FileType('w'), default=sys.stdout)
parser.add_argument('--refr', type=load_smallcounttable)
parser.add_argument('--case-min', type=int, default=5)
parser.add_argument('vcf', type=argparse.FileType('r'))
args = parser.parse_args()

# Pass 1: score every call that carries an ALTWINDOW and group the scored
# calls by their PART= label.
calls_by_partition = defaultdict(list)
for record in args.vcf:
    if record.startswith('#'):
        # Header lines are copied through verbatim and never scored.
        print(record, end='', file=args.out)
        continue
    fields = record.strip().split('\t')

    windowmatch = re.search('ALTWINDOW=([^;\n]+)', record)
    if not windowmatch:
        continue
    window = windowmatch.group(1)
    if len(window) < args.case.ksize():
        message = 'WARNING: ignoring window "{:s}"'.format(window)
        message += ', shorter than ksize {:d}'.format(args.case.ksize())
        print(message, file=sys.stderr)
        continue

    windowmatch = re.search('REFRWINDOW=([^;\n]+)', record)
    refrwindow = windowmatch.group(1) if windowmatch else None
    kmers = args.case.get_kmers(window)
    # Reference k-mers are paired position-by-position with the alt k-mers,
    # so they are only used when both windows have the same length.
    refrkmers = [None] * len(kmers)
    if refrwindow is not None and len(refrwindow) == len(window) and args.refr:
        refrkmers = args.refr.get_kmers(refrwindow)

    # Number of alt k-mers discarded because they occur in the reference.
    # Initialised unconditionally because DROPPED= is reported even when
    # --refr is not given.
    ndropped = 0
    if args.refr:
        nkmers = len(kmers)
        kmerpairs = [(k, r) for k, r in zip(kmers, refrkmers) if args.refr.get(k) == 0]
        kmers = [k for k, r in kmerpairs]
        refrkmers = [r for k, r in kmerpairs]
        ndropped = nkmers - len(kmers)
        # if len(kmers) < args.case.ksize() / 2:
        #     message = 'WARNING: ignoring window "{:s}"'.format(window)
        #     message += ', more than half of the spanning k-mers are present '
        #     message += 'in the reference genome'
        #     print(message, file=sys.stderr)
        #     continue
    if len(kmers) == 0:
        message = 'skipping variant with window "{:s}"'.format(window)
        message += '; all spanning k-mers are present in reference genome'
        print(message, file=sys.stderr)
        continue
    if refrkmers[0] is None:
        refrkmers = None

    abunds = get_abundances(kmers, args.case, args.controls)
    abovethresh = [a for a in abunds[0] if a > args.case_min]
    if len(abovethresh) == 0:
        # passenger call
        continue

    # One reference abundance per surviving k-mer (None when there is no
    # usable reference window). Sized by the number of k-mers, not the number
    # of samples, so zip() in the likelihood functions scores every k-mer.
    refrabunds = [None] * len(kmers)
    if args.refr and refrkmers:
        refrabunds = [args.refr.get(k) for k in refrkmers]

    likedn = likelihood_denovo(abunds, refrabunds=refrabunds, mean=args.mu, sd=args.sigma, error=args.epsilon)
    likefp = likelihood_false(abunds, refrabunds=refrabunds, mean=args.mu, error=args.epsilon)
    likeih = likelihood_inherited(abunds, mean=args.mu, sd=args.sigma, error=args.epsilon)
    likescore = likedn - max(likefp, likeih)

    fields[7] += ';LIKESCORE={:.2f};LLDN={:.2f}'.format(likescore, likedn)
    fields[7] += ';LLFP={:.2f};LLIH={:.2f}'.format(likefp, likeih)
    fields[7] += ';DROPPED={:d}'.format(ndropped)
    fields[7] += ';CASE={:s}'.format(joinlist(abunds[0]))
    fields[7] += ';CTRL1={:s}'.format(joinlist(abunds[1]))
    fields[7] += ';CTRL2={:s}'.format(joinlist(abunds[2]))
    part = re.search('PART=([^;\n]+)', fields[7]).group(1)
    calls_by_partition[part].append(fields)

# Pass 2: within each partition, set the FILTER column of every call that
# scores below the partition maximum. Scores are re-read from the rounded
# LIKESCORE= text, so ties are judged at 2-decimal precision.
allcalls = list()
for calls in calls_by_partition.values():
    scores = [float(re.search('LIKESCORE=([^;\n]+)', c[7]).group(1)) for c in calls]
    maxscore = max(scores)
    for call, score in zip(calls, scores):
        if score < maxscore:
            call[6] = 'PartitionScore'
        allcalls.append(call)

# Highest LIKESCORE first.
allcalls.sort(key=lambda c: float(re.search('LIKESCORE=([^;\n]+)', c[7]).group(1)), reverse=True)
for call in allcalls:
    print(*call, sep='\t', file=args.out)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
import re | |
import sys | |
# Sort records read from stdin by their LIKESCORE= value, highest first.
# Records without a LIKESCORE sort last (as -inf); Python's sort is stable
# even with reverse=True, so ties keep their input order.
records = list()
for line in sys.stdin:
    record = line.strip()
    if record.startswith('#'):
        # Header lines (presumably the VCF header written by the scoring
        # script) are echoed immediately so they stay at the top of the
        # output; otherwise they would score -inf and sink to the bottom.
        print(record)
        continue
    match = re.search('LIKESCORE=([^\n;]+)', record)
    like = float(match.group(1)) if match else float('-inf')
    records.append((like, record))
records.sort(key=lambda r: r[0], reverse=True)
for record in records:
    print(record[1])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment