Last active
October 10, 2018 11:41
-
-
Save adrialuzllompart/264c55ad90f1ea00d245c682458996da to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def _plot_permutation_distribution(deltas, observed_diff, min_effect=None):
    """Plot the permuted differences with reference lines.

    Draws a density histogram (with KDE) of ``deltas``, a dotted black line at
    ``observed_diff`` and, when ``min_effect`` is given, a dotted green line at
    ``min_effect``. Returns the matplotlib Axes.
    """
    _, ax = plt.subplots(figsize=(9, 6))
    # ``distplot`` is deprecated since seaborn 0.11; prefer ``histplot`` when
    # it exists and fall back for older seaborn installs.
    if hasattr(sns, "histplot"):
        sns.histplot(deltas, kde=True, stat="density", ax=ax)
    else:
        sns.distplot(deltas, ax=ax)
    ax.axvline(observed_diff, linestyle=":", linewidth=2.5, alpha=0.5, c="k")
    if min_effect is not None:
        ax.axvline(min_effect, linestyle=":", linewidth=2.5, alpha=0.5, c="g")
    ax.tick_params(labelsize=12, length=0)
    # Hide y tick labels without a FixedFormatter (set_yticklabels('') warns
    # or errors in recent matplotlib versions).
    ax.tick_params(axis="y", labelleft=False)
    return ax


def permutation_test(control, treatment, alpha, r=1000, *, min_effect=None,
                     plot=True, random_state=None):
    """
    Run a two-sided permutation test on the difference in means
    ``mean(treatment) - mean(control)``.

    The pooled observations are shuffled ``r`` times and re-split into groups
    of the original sizes. The p-value is the share of permuted differences at
    least as extreme (in absolute value) as the observed difference.

    Parameters:
        control: pd.Series or array-like
            All the control (A) observations. Missing values (NaN) are dropped.
        treatment: pd.Series or array-like
            All the treatment (B) observations. Missing values (NaN) are dropped.
        alpha: float
            Significance level of the two-sided test, e.g. 0.05.
        r: int
            Number of permutations to draw.
        min_effect: float, optional (keyword-only)
            If given, a green reference line is drawn at this value on the plot.
        plot: bool, optional (keyword-only)
            Whether to plot the histogram of permuted differences (default True).
        random_state: int or numpy.random.Generator, optional (keyword-only)
            Seed or generator, for reproducible permutations.

    Returns:
        float: the two-sided p-value. It is also printed, together with a
        verdict at level ``alpha``, and the histogram is plotted if ``plot``.

    Raises:
        ValueError: if either group has no non-missing observations, or r < 1.
    """
    control_values = np.asarray(pd.Series(control).dropna(), dtype=float)
    treatment_values = np.asarray(pd.Series(treatment).dropna(), dtype=float)
    if control_values.size == 0 or treatment_values.size == 0:
        raise ValueError(
            "control and treatment must each contain at least one "
            "non-missing observation"
        )
    if r < 1:
        raise ValueError("r must be a positive number of permutations")

    n_control = control_values.size
    observed_diff = treatment_values.mean() - control_values.mean()

    # Shuffle the pooled values by position rather than by index label, so
    # overlapping pandas indices in the two groups cannot corrupt the split.
    pooled = np.concatenate([control_values, treatment_values])
    rng = np.random.default_rng(random_state)
    deltas = np.empty(r)
    for i in range(r):
        shuffled = rng.permutation(pooled)
        deltas[i] = shuffled[n_control:].mean() - shuffled[:n_control].mean()

    # Two-sided p-value: fraction of permuted differences at least as extreme
    # as the observed one in either direction. This is well defined when
    # observed_diff == 0 and never exceeds 1. The small relative slack lets
    # permutations that tie the observed value up to float rounding count.
    threshold = abs(observed_diff) * (1.0 - 1e-12)
    p_value = float(np.mean(np.abs(deltas) >= threshold))
    print('P-value = {}'.format(p_value))

    # The p-value is already two-sided, so it is compared with alpha directly.
    if p_value < alpha:
        print('The observed difference is statistically significant.')
    else:
        print('The observed difference is likely due to chance.')

    if plot:
        _plot_permutation_distribution(deltas, observed_diff, min_effect)

    return p_value
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment