import keras as k
import numpy as np
from keras.callbacks import ModelCheckpoint
import matplotlib.pyplot as plt

# Optimizer
adam = k.optimizers.Adam(lr=0.0005, beta_1=0.9, beta_2=0.999)

# Compile the model using the CRF layer's loss and accuracy
model.compile(optimizer=adam, loss=crf.loss_function, metrics=[crf.accuracy, 'accuracy'])
model.summary()
# Save only the model with the best validation accuracy
filepath = "ner-bi-lstm-td-model-{val_accuracy:.2f}.hdf5"
checkpoint = ModelCheckpoint(filepath, monitor='val_accuracy', verbose=1, save_best_only=True, mode='max')
callbacks_list = [checkpoint]

# Fit the model
history = model.fit(X, np.array(y), batch_size=256, epochs=10, validation_split=0.2, verbose=1, callbacks=callbacks_list)
# Plot the training history
plt.style.use('ggplot')

def plot_history(history):
    accuracy = history.history['accuracy']
    val_accuracy = history.history['val_accuracy']
    loss = history.history['loss']
    val_loss = history.history['val_loss']
    x = range(1, len(accuracy) + 1)

    plt.figure(figsize=(12, 5))
    plt.subplot(1, 2, 1)
    plt.plot(x, accuracy, 'b', label='Training acc')
    plt.plot(x, val_accuracy, 'r', label='Validation acc')
    plt.title('Training and validation accuracy')
    plt.legend()
    plt.subplot(1, 2, 2)
    plt.plot(x, loss, 'b', label='Training loss')
    plt.plot(x, val_loss, 'r', label='Validation loss')
    plt.title('Training and validation loss')
    plt.legend()
    plt.show()

plot_history(history)
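
# The snippet above assumes that `model`, `crf`, `X`, and `y` are defined earlier.
# A minimal sketch of such a setup, assuming the CRF layer from keras_contrib and
# placeholder hyperparameters (max_len, n_words, n_tags), could look like the
# following; it is illustrative, not necessarily the exact architecture used here.
from keras.models import Model
from keras.layers import Input, Embedding, Bidirectional, LSTM, TimeDistributed, Dense
from keras_contrib.layers import CRF

# X: padded word-index sequences, shape (n_samples, max_len)
# y: one-hot encoded tag sequences, shape (n_samples, max_len, n_tags)
inp = Input(shape=(max_len,))
emb = Embedding(input_dim=n_words, output_dim=50, input_length=max_len)(inp)
bi = Bidirectional(LSTM(units=100, return_sequences=True, recurrent_dropout=0.1))(emb)
td = TimeDistributed(Dense(50, activation='relu'))(bi)
crf = CRF(n_tags)  # CRF output layer; provides crf.loss_function and crf.accuracy
out = crf(td)
model = Model(inp, out)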