Skip to content

Instantly share code, notes, and snippets.

@khalido
Last active December 10, 2019 05:34
Show Gist options
  • Save khalido/f877cd9f2a84faba82e0c815e89f5726 to your computer and use it in GitHub Desktop.
Save khalido/f877cd9f2a84faba82e0c815e89f5726 to your computer and use it in GitHub Desktop.
[keras] misc keras things #keras
# For live-updating plots see https://github.com/stared/livelossplot or TensorBoard;
# for simple after-the-fact inspection of training, the function below is good enough.
def plot_history(history, log=True, metric="accuracy"):
    """Plot train/val loss and a metric from a Keras ``History`` object.

    Draws two side-by-side subplots (loss per epoch on the left, ``metric``
    per epoch on the right) and then calls ``plt.show()``. Validation curves
    are drawn only for the ``val_*`` keys actually present in the history.

    Args:
        history: the object returned by ``model.fit``; only its ``.history``
            dict (series name -> per-epoch values) and ``.epoch`` list are read.
        log: if True, draw the loss axis on a logarithmic scale.
        metric: name of the series to plot on the right-hand subplot. If it
            is "accuracy" and absent from the history, the older Keras name
            "acc" is used instead.

    Raises:
        KeyError: if ``history.history`` has no "loss" entry.
    """
    hist = history.history
    x = history.epoch

    # Older Keras versions recorded accuracy as "acc"/"val_acc" rather than
    # "accuracy"/"val_accuracy".
    key = metric
    if key == "accuracy" and key not in hist and "acc" in hist:
        key = "acc"
    # Keep the original human-readable labels for the accuracy case.
    title = "Accuracy" if key in ("accuracy", "acc") else metric

    fig, (ax, ax2) = plt.subplots(1, 2, figsize=(13, 6))
    fig.suptitle(f"Train/Val History of Loss and {title}", fontsize=15)

    # Left subplot: loss.
    ax.set_title("Loss against Epochs")
    ax.set_xlabel("number of epochs", fontsize=13)
    if log:
        ax.set_yscale("log")
        ax.set_ylabel("loss (log)", fontsize=13)
    else:
        ax.set_ylabel("loss", fontsize=13)
    ax.plot(x, hist["loss"], label="train_loss")
    # Check each val_* key on its own so that a missing validation metric
    # cannot hide val_loss (the old shared try/except skipped both).
    if "val_loss" in hist:
        ax.plot(x, hist["val_loss"], label="val_loss")
    ax.legend()

    # Right subplot: the chosen metric.
    ax2.set_title(f"{title} against Epochs")
    ax2.set_xlabel("number of epochs", fontsize=13)
    ax2.set_ylabel(title, fontsize=13)
    if key in hist:
        ax2.plot(x, hist[key], label=f"train_{key}")
        if f"val_{key}" in hist:
            ax2.plot(x, hist[f"val_{key}"], label=f"val_{key}")
        ax2.legend()
    else:
        # Model was compiled without this metric: say so on the plot instead
        # of crashing with a KeyError.
        ax2.text(0.5, 0.5, f"'{key}' not recorded in history",
                 ha="center", va="center", transform=ax2.transAxes)

    plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment