Skip to content

Instantly share code, notes, and snippets.

@f-rumblefish
Last active August 4, 2017 09:56
Show Gist options
  • Save f-rumblefish/083e5ff6419e832c045c76de812cef53 to your computer and use it in GitHub Desktop.
Save f-rumblefish/083e5ff6419e832c045c76de812cef53 to your computer and use it in GitHub Desktop.
Recurrent Neural Network (LSTM) in Keras
# Import library in Keras 1.2.0
from keras.layers.recurrent import SimpleRNN, GRU, LSTM
from keras.models import Sequential
from keras.layers import Dense, Activation
from keras.callbacks import EarlyStopping
from keras.utils.visualize_util import plot
# Hyperparameters used by the model definition and the training call.
HIDDEN_SIZE, BATCH_SIZE, NUM_EPOCHS, SEQLEN = (
    128,  # number of LSTM units
    10,   # samples per gradient update (batch_size in model.fit)
    500,  # upper bound on training epochs; EarlyStopping may end sooner
    10,   # timesteps per input sequence (first dim of input_shape)
)
# Build the network: a single LSTM layer reads a (SEQLEN, nb_chars) input
# sequence and outputs only its last timestep, followed by a dense
# projection to nb_chars scores normalized with softmax.
# NOTE(review): nb_chars is not defined in this snippet -- presumably the
# size of the character vocabulary computed earlier; confirm.
model = Sequential([
    LSTM(
        HIDDEN_SIZE,
        input_shape=(SEQLEN, nb_chars),
        return_sequences=False,  # emit only the final timestep's output
        unroll=True,             # unroll the recurrence over SEQLEN steps
    ),
    Dense(nb_chars),
    Activation("softmax"),
])
model.compile(optimizer="rmsprop", loss="categorical_crossentropy")

# Write a diagram of the layer stack, including tensor shapes, to a PNG file.
plot(model, show_shapes=True, to_file='WhatILearnFromTheGodfatherLSTM.png')
# Train on (X, y). Training halts once the training loss ('loss') has not
# improved for 10 consecutive epochs, or after NUM_EPOCHS epochs at most.
# NOTE(review): X and y are not defined in this snippet -- presumably the
# encoded input sequences and one-hot targets prepared earlier; confirm.
early_stop = EarlyStopping(verbose=0, patience=10, monitor='loss')
log = model.fit(
    X,
    y,
    nb_epoch=NUM_EPOCHS,    # Keras 1.x name for the epoch count
    batch_size=BATCH_SIZE,
    callbacks=[early_stop],
    verbose=0,              # no per-epoch progress output
)
# Plot the per-epoch training loss recorded by model.fit and annotate the
# last epoch that was run.
# NOTE(review): plt is not imported in this snippet -- presumably
# matplotlib.pyplot is imported elsewhere; confirm.
log_history = log.history['loss']
log_size = len(log_history)               # number of epochs actually run
log_stop = log_history[-1]                # loss after the last epoch
log_epochs = range(1, log_size + 1)       # 1-based, so the last point is at x == log_size

plt.figure(facecolor='white')
plt.plot(log_epochs, log_history, marker='o', color='b', label='loss')
plt.xlabel("Epoch")
plt.ylabel("Loss History")
# Derive the budget from NUM_EPOCHS so the title cannot drift from the config.
plt.title("LSTM - %d Epochs with Early Stopping" % NUM_EPOCHS)

# EarlyStopping is the only callback, so a run shorter than NUM_EPOCHS means
# it fired; a full-length run was not cut short and is labelled as such.
log_label = "Early Stopping" if log_size < NUM_EPOCHS else "Final Epoch"
log_text = "%s (%d, %0.3f)" % (log_label, log_size, log_stop)
# Put the label left of and above the last point, clamped so it stays on the
# positive x axis for short runs. The +0.2 offset is in loss units, so it
# assumes losses of roughly that scale.
plt.text(max(log_size - 10, 1), log_stop + 0.2, log_text, ha='center', color='b')
plt.grid()
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment