Skip to content

Instantly share code, notes, and snippets.

@kdubovikov
Created October 8, 2017 10:22
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save kdubovikov/d4e5c688fa771227fdf8c924196a59fe to your computer and use it in GitHub Desktop.
Save kdubovikov/d4e5c688fa771227fdf8c924196a59fe to your computer and use it in GitHub Desktop.
Fast random subset sampling with Cython
%%cython
import numpy as np
cimport numpy as np
cimport cython # so we can use cython decorators
from cpython cimport bool # type annotation for boolean
# disable index bounds checking and negative indexing for speedups
@cython.wraparound(False)
@cython.boundscheck(False)
cdef cython_get_sample(np.ndarray arr, arr_len, n_iter, int sample_size,
                       bool fast):
    """Draw one sample of ``sample_size`` elements from ``arr``.

    Parameters
    ----------
    arr : np.ndarray
        Population to sample from. In fast mode it is shuffled IN PLACE
        whenever the current window reaches the end of the array.
    arr_len
        Length of ``arr`` (passed in so the caller computes it only once).
    n_iter
        Zero-based index of the sample being drawn; in fast mode it selects
        which consecutive window of ``arr`` is returned.
    sample_size : int
        Number of elements per sample.
    fast : bool
        If true, return a consecutive window of ``arr`` instead of calling
        ``np.random.choice``.

    Returns
    -------
    np.ndarray
        In fast mode, a slice (view) of ``arr`` with exactly ``sample_size``
        elements; the caller must copy it before the next call, because a
        later reshuffle changes its contents. Otherwise a new array produced
        by ``np.random.choice(arr, sample_size, replace=False)``.

    Raises
    ------
    ValueError
        If ``sample_size`` exceeds ``arr_len`` (the slow path raises the
        same exception type from ``np.random.choice``).
    """
    cdef int start_idx
    if fast:
        if sample_size > arr_len:
            # A window can never be wider than the population; without this
            # check the slice below would silently come back too short.
            raise ValueError(
                "Cannot take a larger sample than population in fast mode")
        if arr_len == 0:
            # Only reachable with sample_size == 0; avoids a modulo by zero.
            return arr[0:0]
        start_idx = (n_iter * sample_size) % arr_len
        if start_idx + sample_size >= arr_len:
            np.random.shuffle(arr)
            # If the window would run past the end, clamp it to the last
            # full window of the freshly shuffled array so the result always
            # has exactly sample_size elements (the original returned a
            # truncated slice here).
            start_idx = min(start_idx, arr_len - sample_size)
        return arr[start_idx:start_idx + sample_size]
    else:
        return np.random.choice(arr, sample_size, replace=False)
@cython.wraparound(False)
@cython.boundscheck(False)
def cython_collect_samples(np.ndarray arr,
                           int sample_size,
                           int n_samples,
                           bool fast=False):
    """Collect ``n_samples`` random samples of ``sample_size`` elements each.

    Parameters
    ----------
    arr : np.ndarray
        Population to sample from. NOTE: with ``fast=True`` the array is
        shuffled in place by ``cython_get_sample`` as sampling proceeds.
    sample_size : int
        Number of elements in every sample.
    n_samples : int
        Number of samples (rows) to draw.
    fast : bool, optional
        Forwarded to ``cython_get_sample``; selects the window-based sampler
        instead of ``np.random.choice``. Defaults to False.

    Returns
    -------
    np.ndarray
        Array of shape ``(n_samples, sample_size)`` with the dtype of
        ``arr``; row ``i`` holds the i-th sample.
    """
    cdef np.ndarray samples
    cdef int arr_len
    cdef int sample_n
    # Allocate all memory in advance: exactly one row per sample (the
    # original allocated n_samples + 1 rows, leaving a bogus all-zero last
    # row), and keep the input dtype so non-int64 data is not truncated.
    samples = np.zeros((n_samples, sample_size), arr.dtype)
    arr_len = len(arr)
    for sample_n in range(n_samples):
        # Row assignment copies the data, so a later in-place reshuffle of
        # ``arr`` cannot alter samples that were already stored.
        samples[sample_n] = cython_get_sample(arr, arr_len, sample_n,
                                              sample_size, fast)
    return samples
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment