Created
April 24, 2017 17:11
-
-
Save dkoslicki/09d7f4326bd0350193848b8b6f61e5d8 to your computer and use it in GitHub Desktop.
Testing typical bottom-k min hash estimate of Jaccard index versus containment estimate of Jaccard index
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import sourmash_lib | |
import numpy as np | |
import matplotlib.pyplot as plt | |
from pybloom import BloomFilter | |
# ---- Experiment parameters ----
n = 10000  # length of the random prefix of seq1
ksize = 10  # k-mer length
h = 100  # number of hashes kept in each bottom-k sketch
# Length of the suffix shared by both sequences; one experiment per value.
i_range = range(1, 50000, 500)
# i_range = [10000]  # single-experiment setting, handy while debugging

# One result slot per experiment for each flavour of Jaccard value recorded.
true_jaccards, estimate_jaccards, corrected_jaccards, containment_jaccards = (
    np.zeros(len(i_range)) for _ in range(4)
)
it = 0  # index of the current experiment within the result arrays
for it, i_size in enumerate(i_range):
    # Two random sequences that share a common suffix of length i_size.
    # seq2's random prefix is 1000x shorter than seq1's, so its k-mer set is
    # the small one whose containment in seq1's k-mer set is estimated below.
    common_string = ''.join(np.random.choice(['A', 'C', 'T', 'G'], i_size))
    seq1 = ''.join(np.random.choice(['A', 'C', 'T', 'G'], n)) + common_string
    # Floor division keeps the size an int: n/1000 is a float on Python 3,
    # which np.random.choice rejects as a size.
    seq2 = ''.join(np.random.choice(['A', 'C', 'T', 'G'], n // 1000)) + common_string

    # Exact Jaccard index of the two k-mer sets.
    # NOTE(review): these are raw (non-canonical) k-mers; confirm this matches
    # how sourmash hashes k-mers, otherwise "true" and estimate differ in kind.
    kmers1 = {seq1[i:i + ksize] for i in range(len(seq1) - ksize + 1)}
    kmers2 = {seq2[i:i + ksize] for i in range(len(seq2) - ksize + 1)}
    true_jaccards[it] = len(kmers1 & kmers2) / float(len(kmers1 | kmers2))

    # Typical sourmash estimate.  E1.jaccard(E2) was checked not to compute
    # containment (as of sourmash commit 6609b3a).
    E1 = sourmash_lib.MinHash(n=h, ksize=ksize)
    E2 = sourmash_lib.MinHash(n=h, ksize=ksize)
    E1.add_sequence(seq1)
    E2.add_sequence(seq2)
    estimate_jaccards[it] = E1.jaccard(E2)

    # Textbook bottom-k estimate: among the h smallest hashes of the union of
    # both sketches, the fraction that occurs in both sketches.
    E1_mins = set(E1.get_mins())
    E2_mins = set(E2.get_mins())
    union_bottom = set(sorted(E1_mins | E2_mins)[:h])
    # Divide by the sample actually drawn, which is smaller than h when the
    # two sketches together hold fewer than h hashes.
    corrected_jaccards[it] = (
        len(union_bottom & E1_mins & E2_mins) / float(len(union_bottom))
        if union_bottom else 0.0
    )

    # Containment estimate.  sourmash does not return the k-mers behind the
    # minimum hashes, so hash every k-mer of seq1 (playing the role of the
    # metagenome) and load all of those hashes into a Bloom filter.
    seq1_kmers_minhash = sourmash_lib.MinHash(n=len(kmers1), ksize=ksize)
    for kmer in kmers1:
        seq1_kmers_minhash.add(kmer)
    seq1_kmers_hashes = set(seq1_kmers_minhash.get_mins())
    bloom = BloomFilter(capacity=len(seq1_kmers_hashes), error_rate=0.001)
    for val in seq1_kmers_hashes:
        bloom.add(val)

    # Fraction of seq2's sketch that the Bloom filter reports as present in
    # seq1 (false positives, at the configured 0.001 rate, can inflate this).
    # The sketch holds min(h, len(kmers2)) hashes, so divide by its real size
    # rather than by h.
    hits = sum(1 for val in E2_mins if val in bloom)
    containment_est = hits / float(len(E2_mins)) if E2_mins else 0.0

    # With A = kmers2 and B = kmers1: |A & B| ~= C * |A| and
    # |A | B| = |A| + |B| - |A & B|.  len(kmers1) is exact here; a HyperLogLog
    # estimate (or possibly the Bloom filter's own count) could stand in for it.
    intersection_est = len(kmers2) * containment_est
    containment_jaccards[it] = intersection_est / (
        len(kmers2) + len(kmers1) - intersection_est)
# Per-experiment error of the typical estimate, plus the same values ordered by
# true Jaccard index so they can be drawn as a curve.  A stable sort keeps tied
# true values in experiment order.
differences = true_jaccards - estimate_jaccards
order = np.argsort(true_jaccards, kind='mergesort')
sorted_true = true_jaccards[order]
sorted_estimates = estimate_jaccards[order]
sorted_differences = sorted_true - sorted_estimates

# Plot (true - estimate) against the true Jaccard index.
plt.figure()
plt.plot(sorted_true, sorted_differences)
axes = plt.gca()
# Make the y-range symmetric about 0 with 50% headroom.  The zero line and the
# Under/Overestimate labels are anchored at y = 0.  The previous approach
# scaled the tick extremes by 1.5, which raised the lower limit above the data
# (and hid y = 0) whenever the lowest tick was positive.
y_low, y_high = axes.get_ylim()
y_bound = 1.5 * max(abs(y_low), abs(y_high))
axes.set_ylim([-y_bound, y_bound])
plt.title('True - estimate Jaccard index')
plt.text(0, 0, 'Underestimate', rotation=90, horizontalalignment='center', verticalalignment='bottom', multialignment='center', color='b', fontsize=14)
plt.text(0, 0, 'Overestimate', rotation=90, horizontalalignment='center', verticalalignment='top', multialignment='center', color='r', fontsize=14)
plt.axhline(0, color='black', linestyle='dashed', linewidth=2)
plt.ylabel('Difference')
plt.xlabel('True Jaccard index')
plt.savefig('Differences.png')
# Line plot of the typical sourmash estimate against the true Jaccard index;
# the dashed diagonal marks a perfect estimate.
# plt.subplots() opens its own figure, so no separate plt.figure() call is
# needed (one used to precede it and left an empty figure behind).
fig, ax = plt.subplots()
ax.plot([0, 1], [0, 1], ls="--", c=".3")
ax.plot(sorted_true, sorted_estimates)
ax.set_ylabel('Estimate Jaccard')
ax.set_xlabel('True Jaccard')
ax.set_title('Typical Jaccard estimate')
fig.savefig('TrueVsEstimate.png')
# Relative error (true - estimate) / true of the typical sourmash estimate.
plt.figure()
true_arr = np.asarray(sorted_true)
# The true Jaccard index can be 0 (e.g. when the shared suffix is shorter than
# ksize).  Skip those points instead of plotting inf/nan from dividing by zero.
nonzero = true_arr > 0
plt.plot(true_arr[nonzero], np.asarray(sorted_differences)[nonzero] / true_arr[nonzero])
plt.axhline(0, color='black', linestyle='dashed', linewidth=2)
plt.ylabel('Relative error')
plt.xlabel('True Jaccard index')
plt.savefig('RelativeError.png')
# Histogram of (true - estimate) for the typical sourmash estimate, with the
# zero line in blue and the mean error in black.
plt.figure()
mean_difference = np.mean(differences)
# density=True replaces normed=1, which matplotlib removed in 3.1.  The return
# value is discarded: unpacking it into `n, bins, patches` clobbered the
# sequence-length parameter n.
plt.hist(differences, 50, density=True, facecolor='green', alpha=0.75)
plt.axvline(0, color='b', linestyle='dashed', linewidth=2)
plt.axvline(mean_difference, color='black', linestyle='dashed', linewidth=2)
plt.title('Histogram of (true - estimate) Jaccard index\n Mean: %f' % mean_difference)
# Label the two sides of the zero line just below the top y tick.
label_y = max(plt.yticks()[0]) - 1
plt.text(0, label_y, 'Underestimate', rotation=0, horizontalalignment='left', verticalalignment='top', multialignment='left', color='b', fontsize=14)
plt.text(plt.xticks()[0][1], label_y, 'Overestimate', rotation=0, horizontalalignment='left', verticalalignment='top', multialignment='left', color='r', fontsize=14)
plt.xlabel('Difference')
plt.savefig('Histogram.png')
# Histogram of (true - corrected estimate), where the corrected estimate is the
# textbook bottom-k estimate recorded in corrected_jaccards.
plt.figure()
corrected_differences = true_jaccards - corrected_jaccards
mean_corrected = np.mean(corrected_differences)
# density=True replaces normed=1, which matplotlib removed in 3.1.  The return
# value is discarded so that the sequence-length parameter n is not clobbered.
plt.hist(corrected_differences, 50, density=True, facecolor='green', alpha=0.75)
plt.axvline(0, color='b', linestyle='dashed', linewidth=2)
plt.axvline(mean_corrected, color='black', linestyle='dashed', linewidth=2)
plt.title('Histogram of (true - corrected estimate) Jaccard index\n Mean: %f' % mean_corrected)
# Label the two sides of the zero line just below the top y tick.
label_y = max(plt.yticks()[0]) - 1
plt.text(0, label_y, 'Underestimate', rotation=0, horizontalalignment='left', verticalalignment='top', multialignment='left', color='b', fontsize=14)
plt.text(plt.xticks()[0][1], label_y, 'Overestimate', rotation=0, horizontalalignment='left', verticalalignment='top', multialignment='left', color='r', fontsize=14)
plt.xlabel('Difference')
plt.savefig('CorrectedHistogram.png')
# ---- Containment-based estimate ----
# Order the containment estimates by true Jaccard index.  A stable sort keeps
# tied true values in experiment order.
plt.figure()
order = np.argsort(true_jaccards, kind='mergesort')
sorted_true = true_jaccards[order]
sorted_containment_estimates = containment_jaccards[order]
containment_differences = sorted_true - sorted_containment_estimates
mean_containment = np.mean(containment_differences)
# density=True replaces normed=1, which matplotlib removed in 3.1.  The return
# value is discarded so that the sequence-length parameter n is not clobbered.
plt.hist(containment_differences, 50, density=True, facecolor='green', alpha=0.75)
plt.axvline(0, color='b', linestyle='dashed', linewidth=2)
plt.axvline(mean_containment, color='black', linestyle='dashed', linewidth=2)
# This histogram shows the containment estimate; the old title was copied from
# the corrected-estimate histogram.
plt.title('Histogram of (true - containment estimate) Jaccard index\n Mean: %f' % mean_containment)
# Label the two sides of the zero line just below the top y tick.
label_y = max(plt.yticks()[0]) - 1
plt.text(0, label_y, 'Underestimate', rotation=0, horizontalalignment='left', verticalalignment='top', multialignment='left', color='b', fontsize=14)
plt.text(plt.xticks()[0][1], label_y, 'Overestimate', rotation=0, horizontalalignment='left', verticalalignment='top', multialignment='left', color='r', fontsize=14)
plt.xlabel('Difference')
plt.savefig('ContainmentHistogram.png')
# Line plot of the containment-based Jaccard estimate against the true Jaccard
# index; the dashed diagonal marks a perfect estimate.
# plt.subplots() opens its own figure, so no separate plt.figure() call is
# needed (one used to precede it and left an empty figure behind).
fig, ax = plt.subplots()
ax.plot([0, 1], [0, 1], ls="--", c=".3")
ax.plot(sorted_true, sorted_containment_estimates)
ax.set_ylabel('Estimate Jaccard')
ax.set_xlabel('True Jaccard')
ax.set_title('Jaccard estimate via containment')
fig.savefig('ContainmentTrueVsEstimate.png')
# Relative error (true - estimate) / true of the containment-based estimate.
plt.figure()
true_arr = np.asarray(sorted_true)
# The true Jaccard index can be 0 (e.g. when the shared suffix is shorter than
# ksize).  Skip those points instead of plotting inf/nan from dividing by zero.
nonzero = true_arr > 0
plt.plot(true_arr[nonzero], np.asarray(containment_differences)[nonzero] / true_arr[nonzero])
plt.axhline(0, color='black', linestyle='dashed', linewidth=2)
plt.ylabel('Relative error')
plt.xlabel('True Jaccard index')
plt.savefig('ContainmentRelativeError.png')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
And here's the important output:
Typical approach histogram:
Typical approach true vs. estimate:
Containment histogram (showing that the estimate is biased):
Containment approach true vs. estimate (showing that the variance and overall accuracy are much improved):