Skip to content

Instantly share code, notes, and snippets.

@strubell
Created June 2, 2019 14:01
Show Gist options
  • Save strubell/adc680aaf789ac0b6607deaeaa941451 to your computer and use it in GitHub Desktop.
Save strubell/adc680aaf789ac0b6607deaeaa941451 to your computer and use it in GitHub Desktop.
import numpy as np
import matplotlib.pyplot as plt
def plot(n, rng=None):
    """Scatter mean(x + [y]) - std(x + [y]) against y for n random samples x.

    Draws ``n`` samples ``x`` as squared standard-normal values (i.e.
    chi-squared with one degree of freedom). Then, for every ``y`` on the grid
    0.00, 0.01, ..., 4.99, it appends ``y`` to ``x`` and computes the mean
    minus the (population, ddof=0) standard deviation of the augmented sample.

    Two vertical lines are drawn in the scatter's colour:
      * dashed: the grid value of ``y`` that maximises the plotted curve;
      * solid: the estimate ``mean(x) + std(x)`` of that maximiser.

    Args:
        n: number of samples to draw.
        rng: optional source of randomness with a numpy-style
            ``normal(size=...)`` method, e.g. ``np.random.default_rng(0)``,
            for reproducible plots. Defaults to the global ``np.random``
            state, matching the original behaviour.

    Returns:
        The collection returned by ``plt.scatter`` (usable as a legend handle).
    """
    sampler = np.random if rng is None else rng
    x = sampler.normal(size=n) ** 2
    y = np.arange(0., 5., 0.01)

    # Row i is x with y[i] appended; reducing all rows at once replaces
    # building len(y) separate Python lists.
    augmented = np.concatenate(
        [np.broadcast_to(x, (y.size, x.size)), y[:, np.newaxis]], axis=1)
    diffs = augmented.mean(axis=1) - augmented.std(axis=1)

    p = plt.scatter(y, diffs)
    c = p.get_facecolor()[0]  # reuse the scatter's colour for its lines

    # Empirical maximiser: resolution 0.01, restricted to the [0, 5) grid.
    plt.axvline(x=y[np.argmax(diffs)], color=c, linestyle='--')
    # Estimated maximiser.
    plt.axvline(x=np.mean(x) + np.std(x), color=c)
    # With population std s and N = n >= 2, the exact maximiser of
    # mean - std over the appended value is mean(x) + s * sqrt((N+1)/(N-1)).
    # The two lines therefore converge as n grows: one appended point
    # shifts the augmented mean and std less and less.
    return p
# One series per sample size; for larger n the dashed (empirical) and solid
# (estimated) maximiser lines of a series should lie closer together.
vals = [5, 10, 100]
plots = [plot(v) for v in vals]
plt.xlabel('y')
# The plotted quantity is mean minus np.std (a standard deviation), not the
# variance.
plt.ylabel('mean - std')
plt.legend(plots, vals, title='n')
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment