Skip to content

Instantly share code, notes, and snippets.

@dbalabka
Last active May 2, 2021 08:56
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save dbalabka/706dd9dcb5fd7e97a0136321c7b87364 to your computer and use it in GitHub Desktop.
Bootstrapping hypothesis testing of mean equality using Efron's algorithm
from math import sqrt
from typing import Tuple
import numpy as np
import time
import numba
from scipy.stats import ttest_ind
@numba.njit(parallel=True, fastmath=True, nogil=True)
def compare_mean(
    z: np.ndarray,
    y: np.ndarray,
    n_samples: int = 10_000,
    two_sided: bool = False,
) -> Tuple[np.ndarray, float, float]:
    """Bootstrap test of the null hypothesis that ``z`` and ``y`` have equal means.

    Each sample is shifted so that its mean equals the mean of the pooled
    sample ``concatenate((z, y))``. The resampling distribution therefore
    satisfies the null hypothesis. ``n_samples`` replicates of the statistic
    from ``_calculate_mean_diff_statistics`` are then drawn by resampling each
    shifted sample with replacement.

    Args:
        z: first sample (1-D array).
        y: second sample (1-D array).
        n_samples: number of bootstrap replicates; must be positive.
        two_sided: if False (the default, and the original behaviour), the
            p-value is the upper-tail fraction ``t >= t_obs``, i.e. a one-sided
            test of ``mean(z) > mean(y)``. If True, it is the fraction with
            ``|t| >= |t_obs|``, comparable to a two-sided t-test.

    Returns:
        ``(t, t_obs, p_value)``: the bootstrap replicates, the statistic on the
        observed data, and the bootstrap p-value (achieved significance level).

    Raises:
        ValueError: if ``n_samples`` is not positive (the p-value divides by it).
    """
    if n_samples <= 0:
        raise ValueError("n_samples must be positive")

    t_obs = _calculate_mean_diff_statistics(z, y)

    # Centre both samples on the pooled mean so that H0 (equal means) holds
    # for the distribution we resample from.
    pooled_mean = np.concatenate((z, y)).mean()
    z_ = z - z.mean() + pooled_mean
    y_ = y - y.mean() + pooled_mean

    # Every slot is written by the loop below, so no zero-initialisation needed.
    t = np.empty(n_samples)
    # NOTE: these draws use Numba's own RNG state, not NumPy's global state.
    # Calling np.random.seed from plain Python does not make them reproducible.
    for i in numba.prange(n_samples):
        zz_ = np.random.choice(z_, z_.shape[0])
        yy_ = np.random.choice(y_, y_.shape[0])
        t[i] = _calculate_mean_diff_statistics(zz_, yy_)

    if two_sided:
        hits = np.sum(np.abs(t) >= abs(t_obs))
    else:
        hits = np.sum(t >= t_obs)
    return t, t_obs, float(hits / n_samples)
@numba.njit(fastmath=True, nogil=True)
def _calculate_mean_diff_statistics(z_: np.ndarray, y_: np.ndarray) -> float:
    """Return Welch's t statistic for the difference of two sample means.

    Computes ``(mean(z_) - mean(y_)) / sqrt(s_z**2 / n_z + s_y**2 / n_y)``,
    where ``s**2`` is the unbiased (``n - 1`` denominator) sample variance.
    The ``n - 1`` denominator means a single-element sample divides by zero.

    Compiled without ``parallel=True`` on purpose. This function is called once
    per bootstrap replicate from inside ``compare_mean``'s ``prange`` loop, where
    that loop already provides the parallelism. A parallel callee there would be
    a nested parallel kernel launch, which Numba's ``workqueue`` threading layer
    does not support.
    """
    n_z = z_.shape[0]
    n_y = y_.shape[0]
    z_mean_ = z_.mean()
    y_mean_ = y_.mean()
    z_dev = z_ - z_mean_
    y_dev = y_ - y_mean_
    z_var = np.sum(z_dev * z_dev) / (n_z - 1)
    y_var = np.sum(y_dev * y_dev) / (n_y - 1)
    return (z_mean_ - y_mean_) / sqrt(z_var / n_z + y_var / n_y)
size = 100_000
# Two normal samples: loc is the mean and scale is the standard deviation.
# The true means differ by 0.05.
z = np.random.normal(0.05, 0.5, size)
y = np.random.normal(0, 0.5, size)
n_samples = 10_000

# Warm-up on a small slice with the same argument types. This triggers
# Numba's JIT compilation here, so compilation time is excluded from the
# timing below.
compare_mean(z[:10], y[:10], 10)

start = time.perf_counter()
t = compare_mean(z, y, n_samples)
end = time.perf_counter()
print(f'p-value: {t[2]}')
print(end - start)

# compare_mean's default p-value is the upper-tail fraction of a Welch
# statistic. The comparable reference is therefore the one-sided Welch
# t-test, not scipy's default two-sided Student t-test.
# (`alternative=` requires SciPy >= 1.6.)
start = time.perf_counter()
t = ttest_ind(z, y, equal_var=False, alternative='greater')
end = time.perf_counter()
print(f'p-value: {t[1]}')
print(end - start)

# Numba debug
# compare_mean.parallel_diagnostics(level=4)
# compare_mean.inspect_types()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment