Skip to content

Instantly share code, notes, and snippets.

@jhurliman
Created October 1, 2018 23:04
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jhurliman/e3c7d6a3f0b4382b186430f27d8e1345 to your computer and use it in GitHub Desktop.
An implementation of "BEST: Bayesian Estimation Supersedes the t Test" using pymc3
from multiprocessing import cpu_count
import matplotlib
matplotlib.use('Agg', warn=False)
import matplotlib.pyplot as plt # noqa: E402
import numpy as np # noqa: E402
import pymc3 as pm # noqa: E402
import six.moves
# Region of practical equivalence (ROPE), in milliseconds: differences of
# means inside this interval are treated as practically zero when
# interpreting the posterior of 'diff of means' (see plot_ab_test).
ROPE = [-2, 2]
# An implementation of "BEST: Bayesian Estimation Supersedes the t Test" using
# pymc3. See <http://www.indiana.edu/~kruschke/BEST/>
def ab_test(y1, y2):
    """Fit the BEST model (Kruschke, "Bayesian Estimation Supersedes the
    t Test") to two samples and return the posterior trace.

    Priors are derived from the pooled data: Normal priors on each group
    mean (centered on the pooled mean, sd twice the pooled sd) and Uniform
    priors on each group sd spanning two orders of magnitude around the
    pooled sd. Both groups share a Student-t likelihood with a common
    shifted-Exponential prior on the degrees of freedom.

    :param y1: array-like of observations for variant A
    :param y2: array-like of observations for variant B
    :return: pymc3 trace with 'diff of means' and 'diff of stds' tracked
    """
    pooled = np.concatenate([y1, y2])
    pooled_sd = pooled.std()

    # Mean priors: centered on the pooled mean, broad (2x pooled sd).
    prior_mean = pooled.mean()
    prior_sd = pooled_sd * 2

    # Sd priors: uniform, two magnitudes below/above the pooled sd.
    sd_lo = pooled_sd / 100
    sd_hi = pooled_sd * 100

    with pm.Model():
        a_mean = pm.Normal('a_mean', prior_mean, sd=prior_sd)
        b_mean = pm.Normal('b_mean', prior_mean, sd=prior_sd)
        a_std = pm.Uniform('a_std', lower=sd_lo, upper=sd_hi)
        b_std = pm.Uniform('b_std', lower=sd_lo, upper=sd_hi)
        # Shifted exponential keeps nu > 1 (mean 30, per Kruschke).
        nu = pm.Exponential('nu_minus_one', 1 / 29.) + 1
        # lam = sigma**-2 matches PyMC3's precision parameterization of
        # the Student-t distribution.
        pm.StudentT('a', nu=nu, mu=a_mean, lam=a_std**-2, observed=y1)
        pm.StudentT('b', nu=nu, mu=b_mean, lam=b_std**-2, observed=y2)
        pm.Deterministic('diff of means', b_mean - a_mean)
        pm.Deterministic('diff of stds', b_std - a_std)
        return pm.sample(6000, tune=1000, njobs=1, progressbar=False,
                         random_seed=list(six.moves.range(cpu_count())))
def plot_ab_test(name, trace, git_commit, prev_git_commit):
    """Build a summary figure for a BEST A/B comparison.

    Draws a forest plot of the four group parameters, overlays a warning
    box when Gelman-Rubin r-hat indicates lack of convergence, and adds a
    posterior plot of 'diff of means' annotated with the ROPE.

    :param name: benchmark name; text after the first '/' is used in the title
    :param trace: pymc3 trace produced by ab_test
    :param git_commit: label for variant B
    :param prev_git_commit: label for variant A
    :return: the matplotlib figure
    """
    group_params = ['a_mean', 'b_mean', 'a_std', 'b_std']

    plt.subplots(figsize=(14, 5))
    pm.forestplot(trace, rhat=False, varnames=group_params,
                  colors='#ff8c00')
    rhat = pm.diagnostics.gelman_rubin(trace, varnames=group_params)
    fig = plt.gcf()
    ax = plt.gca()

    # Convergence check based on: Brooks, S. P., and A. Gelman. 1997.
    # General Methods for Monitoring Convergence of Iterative Simulations.
    # Journal of Computational and Graphical Statistics 7: 434-455.
    warnings = [
        "WARNING: Lack of convergence for " + str(var) +
        ', with rhat: ' + str(rhat[var])
        for var in rhat if rhat[var] >= 1.2
    ]
    if warnings:
        plt.text(0.8, 0.2, "\n".join(warnings),
                 horizontalalignment='center', verticalalignment='center',
                 transform=ax.transAxes,
                 bbox={'facecolor': 'red', 'alpha': 0.5, 'pad': 5})
    ax.set_xlabel('Time (ms)')

    posterior_ax = fig.add_subplot(111)
    pm.plot_posterior(trace, rope=ROPE, varnames=['diff of means'],
                      ax=posterior_ax, color='#ff8c00')
    posterior_ax.set_xlabel('Time (ms)')

    plt.subplots_adjust(bottom=1.1, top=2)
    # Drop everything up to and including the first '/' (no-op when absent).
    truncated_name = name.split('/', 1)[-1]
    plt.title('A: ' + prev_git_commit + '| B: ' + git_commit)
    plt.suptitle(truncated_name + ", Diff_Means:|B - A|", y=2.1, fontsize=16)
    return fig
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment