Skip to content

Instantly share code, notes, and snippets.

@coreyjs
Created February 17, 2021 02:09
Show Gist options
  • Save coreyjs/17e565ca25c48b2c3c0198093f534ee8 to your computer and use it in GitHub Desktop.
Save coreyjs/17e565ca25c48b2c3c0198093f534ee8 to your computer and use it in GitHub Desktop.
encode text
# Encode the text by mapping each character to an integer and vice versa.
# Two lookup dictionaries are created:
#   1. int2char, which maps integers to characters
#   2. char2int, which maps characters to unique integers
# The vocabulary is sorted so the integer ids are identical on every run;
# iterating a bare set of strings follows hash order, which can differ between
# interpreter sessions when string hash randomization is enabled.
# NOTE(review): assumes `text` is a str (defined outside this section).
chars = tuple(sorted(set(text)))
int2char = dict(enumerate(chars))
char2int = {ch: ii for ii, ch in int2char.items()}

# Encode the whole text as an array of integer ids
encoded = np.array([char2int[ch] for ch in text])
def one_hot_encode(arr, n_labels):
    """Return a one-hot encoding of an array of integer labels.

    Arguments
    ---------
    arr: Array of integer labels (any number of dimensions); every value
        must be a valid index into an axis of length n_labels
    n_labels: Length of the one-hot axis appended to the result

    Returns
    -------
    A float32 array of shape (*arr.shape, n_labels) holding 1.0 at the
    position given by each label and 0.0 everywhere else.
    """
    arr = np.asarray(arr)

    # Initialize the encoded array with one row per label. arr.size counts
    # the elements for any rank, unlike multiplying exactly two dimensions.
    one_hot = np.zeros((arr.size, n_labels), dtype=np.float32)

    # Fill the appropriate elements with ones
    one_hot[np.arange(arr.size), arr.ravel()] = 1.

    # Finally reshape it to get back to the original array layout
    return one_hot.reshape((*arr.shape, n_labels))
def get_batches(arr, batch_size, seq_length):
    """Create a generator that returns batches of size
    batch_size x seq_length from arr.

    Arguments
    ---------
    arr: 1-D array you want to make batches from
    batch_size: Batch size, the number of sequences per batch
    seq_length: Number of encoded chars in a sequence

    Yields
    ------
    (x, y) pairs of arrays, each of shape (batch_size, seq_length). y holds
    the same values as x shifted one step to the left; the final target of
    each row's last window wraps around to that row's first value.
    """
    batch_size_total = batch_size * seq_length

    # Number of full batches we can make
    n_batches = len(arr) // batch_size_total
    if n_batches == 0:
        # Not enough data for even one full batch: yield nothing
        return

    # Keep only enough characters to make full batches
    arr = arr[:n_batches * batch_size_total]

    # Reshape into batch_size rows
    arr = arr.reshape((batch_size, -1))

    # Targets are the inputs shifted by one along each row; rolling wraps the
    # last target of every row back to the row's first element.
    targets = np.roll(arr, -1, axis=1)

    # Iterate over the batches using a window of size seq_length
    for n in range(0, arr.shape[1], seq_length):
        # The features
        x = arr[:, n:n + seq_length]
        # The targets, shifted by one
        y = targets[:, n:n + seq_length]
        yield x, y
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment