Skip to content

Instantly share code, notes, and snippets.

@dkoslicki
Created April 24, 2017 17:11
Show Gist options
  • Save dkoslicki/09d7f4326bd0350193848b8b6f61e5d8 to your computer and use it in GitHub Desktop.
Save dkoslicki/09d7f4326bd0350193848b8b6f61e5d8 to your computer and use it in GitHub Desktop.
Testing typical bottom-k min hash estimate of Jaccard index versus containment estimate of Jaccard index
import sourmash_lib
import numpy as np
import matplotlib.pyplot as plt
from pybloom import BloomFilter
# ---- Experiment configuration ----
n = 10000  # sequence length
ksize = 10  # k-mer length
h = 100  # number of hashes in sketch
# Range of intersection sizes to sweep over, one trial per value.
# For a quick single-point run, use i_range = [10000] instead.
i_range = range(1, 50000, 500)

# One zero-initialised result slot per trial for each tracked quantity.
num_trials = len(i_range)
(true_jaccards,
 estimate_jaccards,
 corrected_jaccards,
 containment_jaccards) = (np.zeros(num_trials) for _ in range(4))
it = 0  # index of the current trial within the result arrays
# Run one trial per intersection size: build two sequences sharing a common
# suffix, then record the exact Jaccard index of their k-mer sets alongside
# three estimates of it (sourmash, hand-rolled bottom-h, containment-based).
alphabet = ['A', 'C', 'T', 'G']
for it, i_size in enumerate(i_range):
    # Append a common string to two different random strings
    # (true jaccard will be ~ i_size/n)
    common_string = ''.join(np.random.choice(alphabet, i_size))
    seq1 = ''.join(np.random.choice(alphabet, n)) + common_string
    # Make seq2 a smaller sequence than seq1. Floor division keeps the size
    # an integer on Python 3, where n / 1000 would be a float.
    seq2 = ''.join(np.random.choice(alphabet, n // 1000)) + common_string

    # Calculate exact Jaccard index from the full k-mer sets
    kmers1 = {seq1[i:i + ksize] for i in range(len(seq1) - ksize + 1)}
    kmers2 = {seq2[i:i + ksize] for i in range(len(seq2) - ksize + 1)}
    # In case E1.jaccard(E2) were computing containment, the comparison value
    # would instead be len(kmers1 & kmers2) / float(len(kmers1)).
    # It isn't (using commit 6609b3a).
    true_jaccards[it] = len(kmers1 & kmers2) / float(len(kmers1 | kmers2))

    # Calculate sourmash estimate of Jaccard index
    E1 = sourmash_lib.MinHash(n=h, ksize=ksize)
    E2 = sourmash_lib.MinHash(n=h, ksize=ksize)
    E1.add_sequence(seq1)
    E2.add_sequence(seq2)
    estimate_jaccards[it] = E1.jaccard(E2)

    # Corrected version, making sure that sourmash is doing the right Jaccard
    # estimate: keep the h smallest hashes of the union of both sketches and
    # count how many of them occur in both sketches.
    E1_mins = set(E1.get_mins())
    E2_mins = set(E2.get_mins())
    bottom_h_of_union = set(sorted(E1_mins | E2_mins)[:h])
    shared_mins = E1_mins & E2_mins
    corrected_jaccards[it] = len(bottom_h_of_union & shared_mins) / float(h)

    # Containment version. Sourmash doesn't return the actual k-mers that were
    # hashed to the minimum values, so instead hash every one of the k-mers of
    # seq1 and treat this as the other set.
    seq1_kmers_minhash = sourmash_lib.MinHash(n=len(kmers1), ksize=ksize)
    for kmer in kmers1:
        seq1_kmers_minhash.add(kmer)
    # This is playing the role of a bloom filter, where we can be confident
    # that seq1_kmers_hashes contains all the elements of seq1 (playing the
    # role of the metagenome here).
    seq1_kmers_hashes = set(seq1_kmers_minhash.get_mins())
    # Get the mins from the genome
    seq2_kmers_hashes = E2_mins

    # Bloom filter approach: load every seq1 hash into the filter, then count
    # how many of the seq2 sketch hashes the filter reports as present.
    bloom_filter = BloomFilter(capacity=len(seq1_kmers_hashes), error_rate=0.001)
    for val in seq1_kmers_hashes:
        bloom_filter.add(val)
    int_est = sum(1 for val in seq2_kmers_hashes if val in bloom_filter)

    # Divide by the number of hashes actually in the sketch rather than by h:
    # a short seq2 can have fewer than h distinct k-mers, and dividing by h
    # would then underestimate the containment.
    sketch_size = len(seq2_kmers_hashes)
    containment_est = int_est / float(sketch_size) if sketch_size else 0.0

    # Convert containment C into a Jaccard estimate: |A n B| ~ |B| * C, and
    # |A u B| = |A| + |B| - |A n B|.
    # Could use Hyperloglog here to estimate len(kmers1). The filter's
    # built-in count could also replace len(kmers1) (over-estimate?).
    shared_est = len(kmers2) * containment_est
    containment_jaccards[it] = shared_est / (len(kmers2) + len(kmers1) - shared_est)
# Signed error of the sourmash estimate, per trial and ordered by true value.
differences = true_jaccards - estimate_jaccards
sorted_true = sorted(true_jaccards)
pairs_by_truth = sorted(zip(true_jaccards, estimate_jaccards), key=lambda pair: pair[0])
sorted_estimates = np.array([estimate for _truth, estimate in pairs_by_truth])
sorted_differences = sorted_true - sorted_estimates

# Plot (true - estimate) against the true Jaccard index.
plt.figure()
plt.plot(sorted_true, sorted_differences)
axes = plt.gca()
# Widen the y-range to 1.5x the automatic tick extent.
y_ticks = plt.yticks()[0]
axes.set_ylim([np.min(y_ticks) * 1.5, np.max(y_ticks) * 1.5])
plt.title('True - estimate Jaccard index')
# Label the regions above and below the zero line.
label_style = dict(rotation=90, horizontalalignment='center', multialignment='center', fontsize=14)
plt.text(0, 0, 'Underestimate', verticalalignment='bottom', color='b', **label_style)
plt.text(0, 0, 'Overestimate', verticalalignment='top', color='r', **label_style)
plt.axhline(0, color='black', linestyle='dashed', linewidth=2)
plt.ylabel('Difference')
plt.xlabel('True Jaccard index')
plt.savefig('Differences.png')
# Do a true vs sourmash estimate plot.
# plt.subplots() creates and selects its own figure, so no separate
# plt.figure() call is made (that would leave an extra empty figure open).
fig, ax = plt.subplots()
# Dashed diagonal marks where a perfect estimate would lie.
ax.plot([0, 1], [0, 1], ls="--", c=".3")
ax.plot(sorted_true, sorted_estimates)
ax.set_ylabel('Estimate Jaccard')
ax.set_xlabel('True Jaccard')
ax.set_title('Typical Jaccard estimate')
fig.savefig('TrueVsEstimate.png')
# Do a relative error plot: (true - estimate) / true against the true value.
plt.figure()
true_values = np.asarray(sorted_true, dtype=float)
# A true Jaccard index of 0 (possible for the smallest overlaps) has no
# defined relative error; leave those points as NaN so they are simply not
# drawn, instead of dividing by zero.
relative_errors = np.full(true_values.shape, np.nan)
np.divide(sorted_differences, true_values, out=relative_errors, where=true_values != 0)
plt.plot(true_values, relative_errors)
plt.axhline(0, color='black', linestyle='dashed', linewidth=2)
plt.ylabel('Relative error')
plt.xlabel('True Jaccard index')
plt.savefig('RelativeError.png')
# Histogram of the signed error (true - estimate) of the sourmash estimate.
plt.figure()
# density=True normalises the histogram to unit area; it replaces the older
# `normed` keyword, which current Matplotlib no longer accepts. The return
# value is not bound to names, so the sequence-length setting `n` is left
# untouched.
plt.hist(differences, 50, density=True, facecolor='green', alpha=0.75)
mean_difference = np.mean(differences)
plt.axvline(0, color='b', linestyle='dashed', linewidth=2)  # zero-error reference
plt.axvline(mean_difference, color='black', linestyle='dashed', linewidth=2)  # observed mean
plt.title('Histogram of (true - estimate) Jaccard index\n Mean: %f' % mean_difference)
# Place both labels just below the top y tick.
label_y = max(plt.yticks()[0]) - 1
plt.text(0, label_y, 'Underestimate', rotation=0, horizontalalignment='left', verticalalignment='top', multialignment='left', color='b', fontsize=14)
plt.text(plt.xticks()[0][1], label_y, 'Overestimate', rotation=0, horizontalalignment='left', verticalalignment='top', multialignment='left', color='r', fontsize=14)
plt.xlabel('Difference')
plt.savefig('Histogram.png')
# Histogram of the signed error (true - estimate) of the hand-rolled
# "corrected" bottom-h estimate.
plt.figure()
corrected_differences = true_jaccards - corrected_jaccards
# density=True normalises the histogram to unit area; it replaces the older
# `normed` keyword, which current Matplotlib no longer accepts. The return
# value is not bound to names, so the sequence-length setting `n` is left
# untouched.
plt.hist(corrected_differences, 50, density=True, facecolor='green', alpha=0.75)
mean_corrected_difference = np.mean(corrected_differences)
plt.axvline(0, color='b', linestyle='dashed', linewidth=2)  # zero-error reference
plt.axvline(mean_corrected_difference, color='black', linestyle='dashed', linewidth=2)  # observed mean
plt.title('Histogram of (true - corrected estimate) Jaccard index\n Mean: %f' % mean_corrected_difference)
# Place both labels just below the top y tick.
label_y = max(plt.yticks()[0]) - 1
plt.text(0, label_y, 'Underestimate', rotation=0, horizontalalignment='left', verticalalignment='top', multialignment='left', color='b', fontsize=14)
plt.text(plt.xticks()[0][1], label_y, 'Overestimate', rotation=0, horizontalalignment='left', verticalalignment='top', multialignment='left', color='r', fontsize=14)
plt.xlabel('Difference')
plt.savefig('CorrectedHistogram.png')
# Containment guys: histogram of the signed error (true - estimate) of the
# containment-based Jaccard estimate.
plt.figure()
sorted_true = sorted(true_jaccards)
# Reorder the containment estimates so they line up with sorted_true.
containment_pairs = sorted(zip(true_jaccards, containment_jaccards), key=lambda pair: pair[0])
sorted_containment_estimates = np.array([estimate for _truth, estimate in containment_pairs])
containment_differences = sorted_true - sorted_containment_estimates
# density=True normalises the histogram to unit area; it replaces the older
# `normed` keyword, which current Matplotlib no longer accepts. The return
# value is not bound to names, so the sequence-length setting `n` is left
# untouched.
plt.hist(containment_differences, 50, density=True, facecolor='green', alpha=0.75)
mean_containment_difference = np.mean(containment_differences)
plt.axvline(0, color='b', linestyle='dashed', linewidth=2)  # zero-error reference
plt.axvline(mean_containment_difference, color='black', linestyle='dashed', linewidth=2)  # observed mean
# The title names the containment estimate, which is what this figure shows.
plt.title('Histogram of (true - containment estimate) Jaccard index\n Mean: %f' % mean_containment_difference)
# Place both labels just below the top y tick.
label_y = max(plt.yticks()[0]) - 1
plt.text(0, label_y, 'Underestimate', rotation=0, horizontalalignment='left', verticalalignment='top', multialignment='left', color='b', fontsize=14)
plt.text(plt.xticks()[0][1], label_y, 'Overestimate', rotation=0, horizontalalignment='left', verticalalignment='top', multialignment='left', color='r', fontsize=14)
plt.xlabel('Difference')
plt.savefig('ContainmentHistogram.png')
# Do a true vs containment estimate plot.
# plt.subplots() creates and selects its own figure, so no separate
# plt.figure() call is made (that would leave an extra empty figure open).
fig, ax = plt.subplots()
# Dashed diagonal marks where a perfect estimate would lie.
ax.plot([0, 1], [0, 1], ls="--", c=".3")
ax.plot(sorted_true, sorted_containment_estimates)
ax.set_ylabel('Estimate Jaccard')
ax.set_xlabel('True Jaccard')
ax.set_title('Jaccard estimate via containment')
fig.savefig('ContainmentTrueVsEstimate.png')
# Do a relative error plot for the containment-based estimate:
# (true - estimate) / true against the true value.
plt.figure()
true_values = np.asarray(sorted_true, dtype=float)
# A true Jaccard index of 0 (possible for the smallest overlaps) has no
# defined relative error; leave those points as NaN so they are simply not
# drawn, instead of dividing by zero.
containment_relative_errors = np.full(true_values.shape, np.nan)
np.divide(containment_differences, true_values, out=containment_relative_errors, where=true_values != 0)
plt.plot(true_values, containment_relative_errors)
plt.axhline(0, color='black', linestyle='dashed', linewidth=2)
plt.ylabel('Relative error')
plt.xlabel('True Jaccard index')
plt.savefig('ContainmentRelativeError.png')
@dkoslicki
Copy link
Author

And here's the important output:
Typical approach histogram:
histogram
Typical approach true vs. estimate:
truevsestimate

Containment histogram (depicting that the estimate is biased):
containmenthistogram
Containment approach true vs estimate (depicting that the variance and overall accuracy are much improved):
containmenttruevsestimate

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment