# Enable PlaidML Backend
import os
os.environ["KERAS_BACKEND"] = "plaidml.keras.backend"
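# NOTE: KERAS_BACKEND must be set before Keras itself is imported,
# otherwise Keras silently falls back to its default backend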
# Small LSTM Network to Generate Text for Alice in Wonderland
import numpy
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import Dropout
from keras.layers import LSTM
from keras.callbacks import ModelCheckpoint
from keras.utils import np_utils
# load ASCII text and convert it to lowercase
# "wonderland.txt" can be any sufficiently large plain-text (ASCII) file
filename = "wonderland.txt"
raw_text = open(filename).read()
raw_text = raw_text.lower()
# create mapping of unique chars to integers
chars = sorted(list(set(raw_text)))
char_to_int = dict((c, i) for i, c in enumerate(chars))
# summarize the loaded data
n_chars = len(raw_text)
n_vocab = len(chars)
print "Total Characters: ", n_chars
print "Total Vocab: ", n_vocab
# prepare the dataset of input to output pairs encoded as integers
seq_length = 100
dataX = []
dataY = []
for i in range(0, n_chars - seq_length, 1):
    seq_in = raw_text[i:i + seq_length]
    seq_out = raw_text[i + seq_length]
    dataX.append([char_to_int[char] for char in seq_in])
    dataY.append(char_to_int[seq_out])
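# e.g. with seq_length = 100, characters 0-99 predict character 100,
# characters 1-100 predict character 101, and so on (a sliding window over the text)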
n_patterns = len(dataX)
print "Total Patterns: ", n_patterns
# reshape X to be [samples, time steps, features]
X = numpy.reshape(dataX, (n_patterns, seq_length, 1))
# normalize the integer codes to the range [0, 1]
X = X / float(n_vocab)
# one-hot encode the output variable
y = np_utils.to_categorical(dataY)
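# y has shape (n_patterns, n_vocab): one column per distinct character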
# define the LSTM model
model = Sequential()
model.add(LSTM(256, input_shape=(X.shape[1], X.shape[2])))
model.add(Dropout(0.2))
model.add(Dense(y.shape[1], activation='softmax'))
model.compile(loss='categorical_crossentropy', optimizer='adam')
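# categorical_crossentropy matches the softmax output over the character vocabulary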
# define the checkpoint
filepath = "weights-improvement-{epoch:02d}-{loss:.4f}.hdf5"
checkpoint = ModelCheckpoint(filepath, monitor='loss', verbose=1, save_best_only=True, mode='min')
callbacks_list = [checkpoint]
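# save_best_only=True with mode='min' writes a checkpoint only when the
# monitored training loss improves on the best value seen so far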
# fit the model
# NOTE: when run with the PlaidML backend, the loss printed by .fit() shows as `nan`,
# and at the end of each epoch the checkpoint callback reports `loss did not improve from inf`.
model.fit(X, y, epochs=20, batch_size=128, callbacks=callbacks_list)
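
# --- Optional: sample some text from the trained network ---
# A minimal greedy-sampling sketch of one way to use the trained model
# (helper names such as int_to_char and generated are illustrative).
# It reuses the in-memory `model` instead of reloading a saved checkpoint:
# seed with a random training pattern, repeatedly predict the most likely
# next character, and slide the 100-character window forward.
int_to_char = dict((i, c) for i, c in enumerate(chars))
start = numpy.random.randint(0, len(dataX) - 1)
pattern = list(dataX[start])
generated = []
for i in range(200):
    x = numpy.reshape(pattern, (1, len(pattern), 1)) / float(n_vocab)
    prediction = model.predict(x, verbose=0)
    index = int(numpy.argmax(prediction))
    generated.append(int_to_char[index])
    pattern.append(index)
    pattern = pattern[1:]
print("".join(generated))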