Skip to content

Instantly share code, notes, and snippets.

@dbalabka
Last active April 22, 2021 11:14
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save dbalabka/9c0a3d88312f6d51af8949e85758aa7a to your computer and use it in GitHub Desktop.
Save dbalabka/9c0a3d88312f6d51af8949e85758aa7a to your computer and use it in GitHub Desktop.
scikits-bootstrap: bootstrap resampling with Numba, plus performance testing
import scikits.bootstrap as bootstrap
import numpy as np
import time
import numba
@numba.njit(parallel=True, fastmath=True)
def _calculate_boostrap_mean_stat(data: np.ndarray, n_samples: int) -> np.ndarray:
    """Return ``n_samples`` bootstrap estimates of the mean of ``data``.

    Each estimate is the mean of ``data.shape[0]`` values drawn from ``data``
    with replacement (``np.random.choice``). The iterations are independent,
    so they are distributed across threads with ``numba.prange``; each
    iteration writes only its own output slot.

    Note: inside Numba-compiled code, ``np.random`` uses Numba's own
    generator state, so calling ``np.random.seed`` from regular Python does
    not make these draws reproducible.
    """
    sample_size = data.shape[0]
    means = np.zeros(n_samples)
    for idx in numba.prange(n_samples):
        resample = np.random.choice(data, sample_size)
        means[idx] = resample.mean()
    return means
# Shared input: one 1-D integer array wrapped in a tuple, following the
# tuple-of-arrays convention of scikits.bootstrap.
tdata = (np.random.randint(0, 5, 100_000), )
n_samples = 10_000

# 1) Reference: scikits.bootstrap index generator + NumPy mean per resample.
#    time.perf_counter() is the clock intended for measuring intervals
#    (monotonic, highest available resolution), unlike time.time().
# NOTE: np.mean(*(...)) only works because tdata holds a single array; a
# second array would be passed to np.mean as its `axis` argument.
start = time.perf_counter()
bootindices = bootstrap.bootstrap_indices(tdata[0], n_samples)
stat_old = np.array([np.mean(*(x[indices] for x in tdata))
                     for indices in bootindices])
end = time.perf_counter()
print(end - start)

# 2) NumPy Generator.choice: one resample (with replacement) per iteration.
start = time.perf_counter()
rng = np.random.default_rng()
stat_new1 = np.array([np.mean(*(rng.choice(x, tdata[0].shape[0]) for x in tdata))
                      for _ in range(n_samples)])
end = time.perf_counter()
print(end - start)

# 3) Numba parallel kernel. numba.njit compiles lazily on the first call, so
#    warm it up with the same argument types first; otherwise the measured
#    time would include JIT compilation rather than just the resampling.
_calculate_boostrap_mean_stat(tdata[0], 1)
start = time.perf_counter()
stat_new2 = _calculate_boostrap_mean_stat(tdata[0], n_samples)
end = time.perf_counter()
print(end - start)

print(f'{stat_old.shape} == {stat_new1.shape}')
print(f'{stat_old.shape} == {stat_new2.shape}')
print(f'{stat_old.mean()} == {stat_new1.mean()}')
print(f'{stat_old.mean()} == {stat_new2.mean()}')
assert stat_old.shape == stat_new1.shape
assert stat_old.shape == stat_new2.shape
# The averaged bootstrap means are random and differ slightly between methods.
# Comparing round(x, 3) values fails spuriously whenever the values straddle a
# rounding boundary (e.g. 1.99949 vs 1.99951). An absolute tolerance expresses
# the intended "agree to about 3 decimals" check without that flakiness.
assert np.isclose(stat_old.mean(), stat_new1.mean(), rtol=0, atol=1e-3)
assert np.isclose(stat_old.mean(), stat_new2.mean(), rtol=0, atol=1e-3)
# Numba debug (the kernel is defined in this module, not in scikits.bootstrap):
# _calculate_boostrap_mean_stat.parallel_diagnostics(level=4)
# _calculate_boostrap_mean_stat.inspect_types()
@dbalabka
Copy link
Author

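Example output. The first three lines are elapsed seconds for scikits.bootstrap, NumPy `Generator.choice`, and the Numba kernel. The script times the Numba kernel on its first call, so that figure includes JIT compilation.

22.106656074523926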
15.744634866714478
8.127695083618164
(10000,) == (10000,)
(10000,) == (10000,)
1.998046271 == 1.997934705
1.998046271 == 1.9979548339999997

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment