Skip to content

Instantly share code, notes, and snippets.

@qianyizhang

qianyizhang/one_hot.py

Last active Jun 6, 2018
Embed
What would you like to do?
NumPy one_hot function
import numpy as np
def one_hot(nparray, depth = 0, on_value = 1, off_value = 0):
    """Convert an array of integer class indices into a one-hot encoded array.

    The result gains one trailing axis of length ``depth``: for every
    position ``p`` of ``nparray``, ``out[p + (k,)]`` is ``on_value`` when
    ``nparray[p] == k`` and ``off_value`` otherwise.

    Args:
        nparray: Array-like of non-negative integer indices, any shape
            (0-d and empty arrays included). Non-ndarray input such as a
            nested list is converted with ``np.asarray``.
        depth: Length of the new trailing axis. ``0`` (the default) means
            "infer it" as ``max(nparray) + 1``, or 0 for empty input.
        on_value: Value written at each selected index.
        off_value: Value written at every other position.

    Returns:
        A new array of shape ``(*nparray.shape, depth)``. Its dtype is that
        of ``np.ones(...) * off_value`` (float64 for int/float ``off_value``).

    Raises:
        AssertionError: if any index is negative or not smaller than
            ``depth`` (checked with ``assert``, so skipped under ``python -O``).
    """
    nparray = np.asarray(nparray)
    shape = nparray.shape
    count = nparray.size

    if count:
        # np.min/np.max raise on zero-size arrays, so only inspect non-empty input.
        lowest, highest = nparray.min(), nparray.max()
        if depth == 0:
            depth = highest + 1
        # Negative indices would otherwise wrap around and silently mark the
        # wrong class (e.g. -1 would select the last one).
        assert lowest >= 0, "the min index of nparray: {} is negative".format(lowest)
        assert highest < depth, "the max index of nparray: {} must be smaller than depth: {}".format(highest, depth)

    # Fill a 2-D (count, depth) table where row r encodes the r-th element of
    # nparray in C order, then reshape so the rows line up with nparray's shape.
    out = np.ones((count, depth)) * off_value
    out[np.arange(count), nparray.ravel()] = on_value
    return out.reshape(*shape, depth)
def test_one_hot():
    """Check ``one_hot`` on a 2-D index array with ``depth`` inferred from the data.

    The largest index is 6, so ``depth`` is inferred as 7 and each element
    becomes a length-7 row with a single 1 at its own index.
    """
    # The rows must be wrapped in one outer list: np.array's second
    # positional argument is ``dtype``, so np.array([1,2,3],[4,5,6]) raises.
    a = np.array([[1, 2, 3], [4, 5, 6]])
    expected = np.array(
        [[[0., 1., 0., 0., 0., 0., 0.],
          [0., 0., 1., 0., 0., 0., 0.],
          [0., 0., 0., 1., 0., 0., 0.]],

         [[0., 0., 0., 0., 1., 0., 0.],
          [0., 0., 0., 0., 0., 1., 0.],
          [0., 0., 0., 0., 0., 0., 1.]]]
    )
    result = one_hot(a)
    assert result.shape == expected.shape
    assert np.array_equal(result, expected)
@xychenunc

This comment has been minimized.

Copy link

@xychenunc xychenunc commented Jun 6, 2018

well done

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.