Last active
December 10, 2019 05:34
-
-
Save khalido/f877cd9f2a84faba82e0c815e89f5726 to your computer and use it in GitHub Desktop.
[keras] misc keras things #keras
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# For live-updating plots see https://github.com/stared/livelossplot or TensorBoard;
# for simple post-training inspection the function below is good enough.
def plot_history(history, log=True):
    """Plot training (and, when available, validation) loss and accuracy per epoch.

    Draws two side-by-side subplots, loss on the left and accuracy on the
    right, then calls ``plt.show()``.

    Args:
        history: A Keras ``History`` object. Its ``.history`` dict maps
            metric names (e.g. ``"loss"``, ``"val_loss"``) to per-epoch
            values, and ``.epoch`` holds the matching epoch indices.
        log: If True (default), draw the loss axis on a logarithmic scale.

    The accuracy metric is looked up as ``"accuracy"`` first, falling back
    to ``"acc"`` (the key name used by older Keras releases). The matching
    ``val_`` key is used for the validation curve. Each validation curve is
    drawn only if its key is present, for example when ``fit`` received
    validation data.

    Raises:
        KeyError: If ``history.history`` lacks ``"loss"`` or has neither an
            ``"accuracy"`` nor an ``"acc"`` entry.
    """
    hist = history.history
    epochs = history.epoch

    # Resolve the accuracy key up front so a bad history fails before a
    # half-drawn figure is created.
    acc_key = next((key for key in ("accuracy", "acc") if key in hist), None)
    if acc_key is None:
        raise KeyError(
            "history has no 'accuracy' or 'acc' entry; "
            f"available keys: {sorted(hist)}"
        )
    val_acc_key = "val_" + acc_key

    fig, (ax, ax2) = plt.subplots(1, 2, figsize=(13, 6))
    fig.suptitle("Train/Val History of Loss and Accuracy", fontsize=15)

    # Loss subplot setup.
    ax.set_title("Loss against Epochs")
    ax.set_xlabel("number of epochs", fontsize=13)
    if log:
        ax.set_yscale("log")
        ax.set_ylabel("loss (log)", fontsize=13)
    else:
        ax.set_ylabel("loss", fontsize=13)

    # Accuracy subplot setup.
    ax2.set_title("Accuracy against Epochs")
    ax2.set_xlabel("number of epochs", fontsize=13)
    ax2.set_ylabel("Accuracy", fontsize=13)

    # Training curves are always present.
    ax.plot(epochs, hist["loss"], label="train_loss")
    ax2.plot(epochs, hist[acc_key], label="train_accuracy")

    # Validation curves exist only if validation data was used. Check each
    # key independently so one missing metric does not hide the other.
    if "val_loss" in hist:
        ax.plot(epochs, hist["val_loss"], label="val_loss")
    if val_acc_key in hist:
        ax2.plot(epochs, hist[val_acc_key], label="val_accuracy")

    ax.legend()
    ax2.legend()
    plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment