Skip to content

Instantly share code, notes, and snippets.

@gabrieldernbach
Last active March 17, 2022 15:02
Show Gist options
  • Save gabrieldernbach/82a7ffcd779f977b063cb00d238fa238 to your computer and use it in GitHub Desktop.
Save gabrieldernbach/82a7ffcd779f977b063cb00d238fa238 to your computer and use it in GitHub Desktop.
import numpy as np
from scipy.stats import chi2_contingency
# Example 3x4 contingency table (650 observations) taken from
# https://en.wikipedia.org/wiki/Chi-squared_test#Example_chi-squared_test_for_categorical_data
X = np.array([[90, 60, 104, 95],
              [30, 50, 51, 20],
              [30, 40, 45, 35]])

# Reference result from scipy, used below to check the hand-written code.
# chi2_contingency returns (statistic, p-value, degrees of freedom, expected
# counts); the degrees of freedom land in `ref_ddof`. With 6 degrees of freedom
# scipy's default Yates continuity correction (applied only when dof == 1) is
# not used, so the plain Pearson statistic is directly comparable.
(ref_chi2, ref_pval, ref_ddof, ref_expected) = chi2_contingency(X)
# introducing the test statistic
def ex(obs):
    """Return the expected cell counts of a 2-D contingency table under independence.

    Each expected count is ``row_total * column_total / grand_total``, i.e. the
    grand total times the product of the empirical row and column marginal
    probabilities. The result therefore has the same row and column totals
    as ``obs``.

    Parameters
    ----------
    obs : array_like
        Observed counts as a 2-D table (rows x columns); nested lists are accepted.

    Returns
    -------
    numpy.ndarray
        Float array of expected counts with the same shape as ``obs``.
    """
    obs = np.asarray(obs)
    # keepdims yields shapes (r, 1) and (1, c); broadcasting their product
    # forms the outer product of row totals and column totals.
    row_totals = obs.sum(axis=1, keepdims=True)
    col_totals = obs.sum(axis=0, keepdims=True)
    return (row_totals * col_totals) / obs.sum()
def chi2_statistic(obs):
    """Return Pearson's chi-squared statistic for the contingency table ``obs``.

    The statistic is the sum over all cells of
    ``(observed - expected) ** 2 / expected``, where the expected counts are
    obtained from ``ex`` (rows and columns independent).
    """
    expected = ex(obs)
    squared_deviation = np.square(obs - expected)
    return np.sum(squared_deviation / expected)
# Sanity check against scipy. Floats are compared with a tolerance instead of
# exact equality, and the expected-count arrays are compared element-wise:
# `assert a == b` on multi-element arrays raises "truth value of an array is
# ambiguous". assert_allclose also still runs under `python -O`.
np.testing.assert_allclose(chi2_statistic(X), ref_chi2)
np.testing.assert_allclose(ex(X), ref_expected)
# getting a p-value by sampling the h0 distribution
def sample_h0(x, rng=None):
    """Draw one contingency table from the independence (H0) model fitted to ``x``.

    The model is a multinomial over all cells with ``x.sum()`` trials. The
    probability of cell (i, j) is the product of the empirical row-i and
    column-j marginal probabilities of ``x``. A sampled table has the same
    grand total as ``x`` and, in expectation, the same marginals. Its realized
    row/column totals vary from draw to draw.

    Parameters
    ----------
    x : array_like of int
        Observed 2-D contingency table of non-negative integer counts.
        Individual cells may be zero.
    rng : numpy.random.Generator or numpy.random.RandomState, optional
        Source of randomness. Defaults to the global ``numpy.random`` state,
        so ``np.random.seed`` still controls reproducibility.

    Returns
    -------
    numpy.ndarray
        Integer table with the same shape and grand total as ``x``.

    Raises
    ------
    TypeError
        If ``x`` does not hold integer counts.
    ValueError
        If ``x`` is not 2-D, has a negative entry, or sums to zero.
    """
    x = np.asarray(x)
    # Accept any integer dtype (int32, int64, unsigned), not only the platform default.
    if not np.issubdtype(x.dtype, np.integer):
        raise TypeError(f"expected integer counts, got dtype {x.dtype}")
    if x.ndim != 2:
        raise ValueError(f"expected a 2-D contingency table, got shape {x.shape}")
    if (x < 0).any():
        raise ValueError("counts must be non-negative")
    total = int(x.sum())
    if total == 0:
        raise ValueError("table must contain at least one observation")
    if rng is None:
        rng = np.random

    # Outer product of row and column totals, normalized to cell probabilities.
    outer = x.sum(axis=1, keepdims=True) * x.sum(axis=0, keepdims=True)
    probs = (outer / outer.sum()).ravel()
    # `total` independent categorical draws over the cells are distributed as a
    # single multinomial draw, so sample the cell counts directly.
    counts = rng.multinomial(total, probs)
    return counts.reshape(x.shape)
# Monte Carlo p-value: compare the observed statistic with statistics of tables
# drawn from the independence model.
n_resamples = 100_000
samps = [sample_h0(X) for _ in range(n_resamples)]
vals = [chi2_statistic(x) for x in samps]
observed_chi2 = chi2_statistic(X)
# ">=" because the statistic of a discrete table can tie with the observed
# value. The +1 terms count the observed table as one of the draws (standard
# Monte Carlo test estimate), so the p-value is never reported as exactly 0.
pval = (1 + np.sum(np.asarray(vals) >= observed_chi2)) / (1 + n_resamples)
print(ref_pval, pval)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment