Skip to content

Instantly share code, notes, and snippets.

@BenedictWilkins
Last active October 5, 2020 17:02
Show Gist options
  • Save BenedictWilkins/d07d8dcfc0aa2963d6686619fae2d090 to your computer and use it in GitHub Desktop.
Save BenedictWilkins/d07d8dcfc0aa2963d6686619fae2d090 to your computer and use it in GitHub Desktop.
"""
One-hot encoding of multi-dimensional numpy arrays.

``onehot(x, shape)`` returns an array of ``shape`` that has exactly one more
dimension than ``x``. That extra dimension indexes the classes: each element
of ``x`` selects the position that is set to 1 along it, and every other
entry is 0. The one-hot dimension is the first position where ``shape``
differs from ``x.shape`` (or the last position, if ``shape`` starts with all
of ``x.shape``). The encoding is computed with numpy advanced indexing rather
than a Python loop.

Suggestions for a more efficient implementation are welcome.

Example 1 -- classes appended as the last axis:

    >>> import numpy as np
    >>> x = np.array([[0, 2], [1, 0]])
    >>> onehot(x, (2, 2, 3)).tolist()
    [[[1, 0, 0], [0, 0, 1]], [[0, 1, 0], [1, 0, 0]]]

Example 2 -- classes inserted as the middle axis, because ``shape[1] == 3``
is the first entry that differs from ``x.shape``:

    >>> onehot(x, (2, 3, 2)).tolist()
    [[[1, 0], [0, 0], [0, 1]], [[0, 1], [1, 0], [0, 0]]]
"""
import numpy as np


def onehot(x, shape, dtype=np.uint8, axis=None):
    """Return the one-hot encoding of the integer label array ``x``.

    The result has exactly one more dimension than ``x``. That extra
    ("one-hot") dimension is placed at ``axis`` and its length is the number
    of classes. For every element of ``x``, the entry at index ``x[...]``
    along that dimension is set to 1; all other entries are 0. The encoding
    uses numpy advanced indexing, so there is no Python-level loop over
    elements.

    Args:
        x: array-like of integer class labels, each in ``[0, shape[axis])``.
        shape: shape of the result, i.e. ``x.shape`` with the number of
            classes inserted at the one-hot axis.
        dtype: dtype of the returned array (default ``np.uint8``).
        axis: position of the one-hot dimension within ``shape``.
            - If ``None`` (default), it is inferred as the first position
              where ``shape`` differs from ``x.shape``, or the last position
              if ``shape`` starts with all of ``x.shape``.
            - Pass it explicitly when that inference is ambiguous. For
              example, ``x.shape == (2, 3)`` with ``shape == (2, 3, 3)``
              infers ``axis=2``; pass ``axis=1`` to put the classes in the
              middle instead.
            - Negative values count from the end of ``shape``.

    Returns:
        ndarray of the given ``shape`` and ``dtype`` holding the encoding.

    Raises:
        ValueError: if ``shape`` does not have exactly one more dimension
            than ``x``, if ``axis`` is out of range, or if ``shape`` with the
            one-hot dimension removed is not equal to ``x.shape``.
        IndexError: if any label in ``x`` lies outside ``[0, shape[axis])``.
    """
    # Indexing trick adapted from https://stackoverflow.com/a/46103129/ (@Divakar).
    x = np.asarray(x)
    shape = tuple(shape)
    ndim = len(shape)
    if ndim != x.ndim + 1:
        raise ValueError(
            "shape {} must have exactly one more dimension than x.shape {}".format(
                shape, x.shape
            )
        )

    if axis is None:
        # First position where the requested shape departs from x.shape.
        # If shape merely extends x.shape, the one-hot axis is the last one.
        axis = next(
            (i for i, (a, b) in enumerate(zip(shape, x.shape)) if a != b),
            x.ndim,
        )
    elif -ndim <= axis < ndim:
        axis %= ndim
    else:
        raise ValueError("axis {} is out of range for shape {}".format(axis, shape))

    # All dimensions except the one-hot one must equal x.shape, in order.
    # Comparing directly (rather than counting matches) also accepts a
    # one-hot dimension of size 1.
    if shape[:axis] + shape[axis + 1:] != x.shape:
        raise ValueError(
            "shape {} does not match x.shape {} outside one-hot axis {}".format(
                shape, x.shape, axis
            )
        )

    num_classes = shape[axis]
    # numpy would silently wrap negative labels around to the last classes,
    # so reject them explicitly along with labels that are too large.
    if x.size and (x.min() < 0 or x.max() >= num_classes):
        raise IndexError(
            "labels must lie in [0, {}); got range [{}, {}]".format(
                num_classes, x.min(), x.max()
            )
        )

    # Open (sparse) index grids, one per dimension of x, that broadcast
    # against x. Inserting x at `axis` means the assignment below writes
    # out[i0, ..., x[i0, ..., ik], ..., ik] = 1 for every element of x.
    index = list(np.indices(x.shape, sparse=True))
    index.insert(axis, x)
    out = np.zeros(shape, dtype=dtype)
    out[tuple(index)] = 1
    return out
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment