Skip to content

Instantly share code, notes, and snippets.

@doleron
Last active April 9, 2023 13:03
Show Gist options
  • Save doleron/16ea80c26b30a6148fb129aee1bde789 to your computer and use it in GitHub Desktop.
Save doleron/16ea80c26b30a6148fb129aee1bde789 to your computer and use it in GitHub Desktop.
def checking_wrong_predictions(*, dataset=None, predictor=None, threshold=0.5):
    """Plot every image pair in ``dataset`` that the model gets wrong.

    Runs ``predictor.predict`` on each ``(images, labels)`` batch. For each
    misclassified pair, it shows the two images side by side. The figure title
    gives the predicted value and the error kind.

    A pair counts as wrong when its label is ``> 0`` but the prediction is
    ``>= threshold`` (reported as FALSE NEGATIVE). It also counts as wrong when
    its label is ``< 1`` but the prediction is ``< threshold`` (reported as
    FALSE POSITIVE). The title calls the prediction a "Distance", so a low value
    presumably means "same pair class" -- confirm against the model.

    Args:
        dataset: Iterable of ``(images, labels)`` batches. ``images[0]`` and
            ``images[1]`` hold the two sides of each pair, and ``labels[i]``
            must support ``.numpy()``. Defaults to the module-level
            ``validation_ds``.
        predictor: Object exposing ``predict(images, verbose=0)`` that returns
            one row per pair, with the score in column 0 (presumably a Keras
            model). Defaults to the module-level ``model``.
        threshold: Decision boundary applied to the prediction. The default,
            0.5, is the value that was previously hard-coded.

    Returns:
        None. Each figure is shown with ``plt.show()`` and then closed.
    """
    if dataset is None:
        dataset = validation_ds
    if predictor is None:
        predictor = model

    for images, labels in dataset:
        predictions = predictor.predict(images, verbose=0)
        # Iterate over the real batch length. The last batch of a dataset can
        # be shorter than VALIDATION_BATCH_SIZE, and a fixed-size range would
        # index past its end.
        for i in range(len(predictions)):
            label = labels[i].numpy()
            prediction = predictions[i][0]
            is_wrong = (label > 0 and prediction >= threshold) or (
                label < 1 and prediction < threshold
            )
            if not is_wrong:
                continue

            # Convert only the pairs that are actually plotted. Pixels are
            # presumably floats in [0, 1], so scale back to 8-bit for imshow.
            img_a = (images[0][i].numpy() * 255).astype("uint8")
            img_b = (images[1][i].numpy() * 255).astype("uint8")

            fig, axs = plt.subplots(1, 2)
            for ax, img in zip(axs, (img_a, img_b)):
                ax.imshow(img)
                ax.axis("off")

            # A wrong pair with prediction >= threshold must have label > 0,
            # so it is a false negative. Any other wrong pair is a false
            # positive.
            error_kind = "NEGATIVE" if prediction >= threshold else "POSITIVE"
            fig.suptitle(
                f"Distance = {prediction:.4f} (FALSE {error_kind})", fontsize=15
            )
            plt.show()
            # Release the figure so a long run does not pile up open figures
            # (matters with non-interactive backends).
            plt.close(fig)
# Show every validation pair the model misclassifies (one figure per pair).
checking_wrong_predictions()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment