Skip to content

Instantly share code, notes, and snippets.

@Chiraagkv
Last active August 27, 2021 14:32
Show Gist options
  • Save Chiraagkv/19753b5b09909e82ad6f0f6c68c756ac to your computer and use it in GitHub Desktop.
Save Chiraagkv/19753b5b09909e82ad6f0f6c68c756ac to your computer and use it in GitHub Desktop.
def plot_histories(*args):
    """
    Plot training/validation loss and accuracy curves for one or more runs.

    Builds a single figure with one row per history and two columns: loss
    on the left, accuracy on the right. Each panel shows the training curve
    and the validation curve against epoch number (starting at 1).

    Parameters
    ----------
    *args
        One or more objects exposing a ``history`` mapping that contains the
        keys ``"loss"``, ``"val_loss"``, ``"accuracy"`` and ``"val_accuracy"``,
        each a per-epoch sequence of the same length. (Presumably the object
        returned by Keras ``Model.fit`` -- not checked here.)

    Raises
    ------
    ValueError
        If no histories are given.
    KeyError
        If a history lacks one of the four required keys.
    """
    if not args:
        raise ValueError("plot_histories() requires at least one history")

    # squeeze=False keeps ``ax`` 2-D even when only one history is passed;
    # otherwise subplots(1, 2) returns a 1-D array and ax[i, 0] would fail.
    fig, ax = plt.subplots(len(args), 2, figsize=(16, 7 * len(args)), squeeze=False)
    fig.suptitle('Comparing Model Histories')

    for i, run in enumerate(args):
        hist = run.history
        # One x value per recorded epoch, numbered from 1; computed once and
        # shared by all four curves of this model.
        epochs = np.arange(1, len(hist["loss"]) + 1)

        loss_ax, acc_ax = ax[i]

        loss_ax.set_title(f"Model {i + 1} Loss")
        loss_ax.plot(epochs, hist["loss"], epochs, hist["val_loss"])
        loss_ax.set_xlabel('Epochs')
        loss_ax.legend(["Train Loss", "Val Loss"])

        acc_ax.set_title(f"Model {i + 1} Accuracy")
        acc_ax.plot(epochs, hist["accuracy"], epochs, hist["val_accuracy"])
        acc_ax.set_xlabel('Epochs')
        acc_ax.legend(["Train Accuracy", "Val Accuracy"])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment