Skip to content

Instantly share code, notes, and snippets.

@adrialuzllompart
Last active October 10, 2018 11:41
Show Gist options
  • Save adrialuzllompart/264c55ad90f1ea00d245c682458996da to your computer and use it in GitHub Desktop.
Save adrialuzllompart/264c55ad90f1ea00d245c682458996da to your computer and use it in GitHub Desktop.
def permutation_test(control, treatment, alpha, r=1000, *, min_effect=None,
                     plot=True, random_state=None):
    """
    Run a two-sided permutation test on the difference in means
    between control and treatment.

    The observations of both groups are pooled and randomly re-split into
    groups of the original sizes ``r`` times. The p-value is the share of
    re-splits whose absolute difference in means is at least as large as
    the absolute observed difference.

    Parameters:
        control: pd.Series
            A pandas series with all the control (A) observations
        treatment: pd.Series
            A pandas series with all the treatment (B) observations
        alpha: float
            Significance level the two-sided p-value is compared against
        r: int
            Number of iterations for the permutation test
        min_effect: float, optional
            If given, an extra green vertical line is drawn at this value
            on the histogram. Nothing is drawn when it is None.
        plot: bool
            Whether to plot the histogram of the random differences
        random_state: int or np.random.RandomState, optional
            Seed (or generator) used for the shuffles, for reproducibility

    Returns:
        The two-sided p-value as a float. The p-value and the verdict are
        also printed and, when ``plot`` is true, the histogram of the random
        differences is drawn together with the observed difference.

    Raises:
        ValueError: if either group is empty, ``r`` is smaller than 1, or
            the observed difference is NaN.
    """
    # record how many observations we have in each group
    n_control = len(control)
    n_treatment = len(treatment)
    if n_control == 0 or n_treatment == 0:
        raise ValueError('control and treatment must both be non-empty')
    if r < 1:
        raise ValueError('r must be a positive integer')

    # record the observed difference
    observed_diff = treatment.mean() - control.mean()
    if np.isnan(observed_diff):
        # a NaN difference would compare False everywhere and be reported
        # as "significant" with a p-value of 0
        raise ValueError('the observed difference in means is NaN')

    # combine the results from both groups; the index is rebuilt because
    # control and treatment usually share index labels (0..n-1), which makes
    # any label-based split of the pooled data ambiguous
    pooled = pd.concat([control, treatment], axis=0, ignore_index=True)
    n_total = n_control + n_treatment

    if isinstance(random_state, np.random.RandomState):
        rng = random_state
    else:
        rng = np.random.RandomState(random_state)

    # run a permutation test r times, splitting a shuffled set of
    # positions so every observation lands in exactly one group
    deltas = np.empty(r, dtype=float)
    for i in range(r):
        order = rng.permutation(n_total)
        sample_control = pooled.iloc[order[:n_control]]
        sample_treatment = pooled.iloc[order[n_control:]]
        deltas[i] = sample_treatment.mean() - sample_control.mean()

    # two-sided p-value: how often a random split is at least as extreme
    # as the observed difference, in either direction
    p_value = float(np.mean(np.abs(deltas) >= abs(observed_diff)))
    print('P-value = {}'.format(p_value))

    # the p-value is already two-sided, so it is compared with alpha itself
    if p_value < alpha:
        print('The observed difference is statistically significant.')
    else:
        print('The observed difference is likely due to chance.')

    if plot:
        # plot the histogram of random differences
        # together with the observed difference
        _, ax = plt.subplots(figsize=(9, 6))
        sns.distplot(deltas, ax=ax)
        ax.axvline(observed_diff, linestyle=':', linewidth=2.5, alpha=0.5,
                   c='k')
        if min_effect is not None:
            ax.axvline(min_effect, linestyle=':', linewidth=2.5, alpha=0.5,
                       c='g')
        ax.tick_params(labelsize=12, length=0)
        ax.set_yticklabels([])

    return p_value
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment