@ihincks
Created June 12, 2018 20:00
Function that generalizes np.random.choice to n-D arrays
import numpy as np


def random_choice(mat, n_samples=None, axis=0):
    """
    Replaces the given axis with n_samples random choices (with replacement)
    of values already along that axis. Each 1-D slice along the axis is
    resampled independently.

    Example (output is random, so the sampled values will differ)::

        A = np.arange(15).reshape(3, 5)
        print(A)
        print(random_choice(A, n_samples=6, axis=1))
        print(random_choice(A, axis=0))

        [[ 0  1  2  3  4]
         [ 5  6  7  8  9]
         [10 11 12 13 14]]
        [[ 1  2  3  4  2  2]
         [ 7  9  9  7  5  6]
         [11 10 11 10 10 14]]
        [[ 5  6  7  3  4]
         [ 5 11  2  8 14]
         [ 0  6  7 13  4]]

    :param array-like mat: The input matrix.
    :param n_samples: How many samples to draw (with replacement) along the
        given axis. If `None`, the shape is not changed.
    :param int axis: Which axis to resample along.
    """
    mat = np.asarray(mat)
    axis = axis % mat.ndim  # allow negative axes such as -1
    shape = list(mat.shape)
    axis_size = shape[axis]
    rest_size = mat.size // axis_size
    n_samples = axis_size if n_samples is None else n_samples
    # Move the resampled axis to the end and flatten all other axes.
    perm = list(range(0, axis)) + list(range(axis + 1, mat.ndim)) + [axis]
    choices = np.random.randint(0, axis_size, size=[rest_size, n_samples])
    mat = mat.transpose(perm).reshape(rest_size, axis_size)
    # Offset each row's choices so they index into the flattened array.
    mat = np.take(mat, choices + np.arange(rest_size)[:, np.newaxis] * axis_size)
    # Restore the original axis order, with n_samples in place of axis_size.
    shape.pop(axis)
    shape += [n_samples]
    return mat.reshape(shape).transpose(np.argsort(perm))
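
The same per-slice resampling can probably be written more compactly with `np.take_along_axis` (available since NumPy 1.15), which avoids the transpose-and-flatten bookkeeping. The sketch below is an alternative, not the gist's own method. The name `random_choice_along_axis` and the optional `rng` argument are assumptions added here for illustration.

```python
import numpy as np


def random_choice_along_axis(mat, n_samples=None, axis=0, rng=None):
    # Hypothetical helper name; `rng` is an added, optional argument.
    mat = np.asarray(mat)
    axis = axis % mat.ndim  # normalize negative axes
    n_samples = mat.shape[axis] if n_samples is None else n_samples
    rng = np.random.default_rng() if rng is None else rng
    # One independent index per output element, drawn along `axis`.
    idx_shape = list(mat.shape)
    idx_shape[axis] = n_samples
    idx = rng.integers(0, mat.shape[axis], size=idx_shape)
    # Gather values slice by slice along `axis`.
    return np.take_along_axis(mat, idx, axis=axis)


A = np.arange(15).reshape(3, 5)
out = random_choice_along_axis(A, n_samples=6, axis=1, rng=np.random.default_rng(0))
print(out.shape)  # (3, 6)
```

Every row of `out` contains only values from the matching row of `A`, since each slice along the chosen axis is sampled on its own. Passing a seeded generator makes the draw reproducible.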