Last active
July 5, 2022 08:34
-
-
Save ZhengHe-MD/48e633ec49a35b5259908cf468dc073b to your computer and use it in GitHub Desktop.
AB-Permutation-Test
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import matplotlib.pyplot as plt | |
# Observed A/B outcomes: each group is a 0/1 array holding `succ_*` ones
# followed by zeros, `total_*` entries in all.
total_A = 3469
succ_A = 98
total_B = 2798
succ_B = 43

# Number of permutation replicates to draw.
num_simulations = 10000

observation_A = np.repeat([1, 0], [succ_A, total_A - succ_A])
observation_B = np.repeat([1, 0], [succ_B, total_B - succ_B])
def diff(data_A, data_B):
    """Return the sum of data_A minus the sum of data_B.

    Both arguments are reduced with ``np.sum``, so for 0/1 outcome arrays
    the result is the difference in the number of ones.
    """
    total_a = np.sum(data_A)
    total_b = np.sum(data_B)
    return total_a - total_b
def permutation_sample(data1, data2):
    """Shuffle the pooled data and split it back into two samples.

    The two inputs are concatenated, randomly permuted with
    ``np.random.permutation``, and cut so that the first returned array has
    ``len(data1)`` elements and the second holds the remainder.

    Returns:
        A tuple ``(perm_sample_1, perm_sample_2)``.
    """
    shuffled = np.random.permutation(np.concatenate((data1, data2)))
    split_at = len(data1)
    return shuffled[:split_at], shuffled[split_at:]
def permutation_replicate(data_1, data_2, func, size=1):
    """Draw ``size`` permutation replicates of a two-sample statistic.

    Each replicate is produced by calling ``permutation_sample`` on the two
    data sets and passing the two resulting samples to ``func``.

    Args:
        data_1: First data set.
        data_2: Second data set.
        func: Callable invoked as ``func(sample_1, sample_2)``; its result is
            stored as one element of the output array.
        size: Number of replicates to generate.

    Returns:
        A NumPy array of length ``size`` holding the replicate values.
    """
    replicates = np.empty(size)
    for idx in range(size):
        # A fresh shuffle of the pooled data feeds every evaluation.
        replicates[idx] = func(*permutation_sample(data_1, data_2))
    return replicates
# Simulate the permutation distribution of the statistic for the two groups.
perm_replicates = permutation_replicate(
    observation_A, observation_B, diff, num_simulations
)

# Histogram with unit-wide bins; the red band covers the bin that starts at
# the observed difference in success counts.
pooled_successes = succ_A + succ_B
observed_gap = succ_A - succ_B
plt.hist(perm_replicates, bins=list(range(-pooled_successes, pooled_successes)))
plt.axvspan(observed_gap, observed_gap + 1, color='red', alpha=0.5)
plt.show()

# p-value: share of replicates at least as large as the observed statistic.
observed_stat = diff(observation_A, observation_B)
p_value = np.sum(perm_replicates >= observed_stat) / num_simulations
print(p_value)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment