Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
Helpers for generating images to log to TensorBoard.
def plot_confusion_matrix(cm, class_names=None):
    """Return a matplotlib figure containing the plotted confusion matrix.

    Cells are shaded by raw count and annotated with the row-normalized value
    (fraction of each true class), rounded to two decimals.

    Args:
        cm (array, shape = [n, n]): a confusion matrix of integer classes;
            rows are true labels and columns are predicted labels, matching
            the axis labels set below.
        class_names (array, shape = [n]): String names of the integer classes.
            When omitted, a module-level ``class_names`` is used if one exists.
            This mirrors the old ``class_names=class_names`` default, which
            raised NameError at definition time when that global was missing.
            Otherwise the class indices are used as names.

    Returns:
        matplotlib.figure.Figure: the new figure. It is left open; pass it to
        ``plot_to_image`` (or ``plt.close``) when finished with it.
    """
    cm = np.asarray(cm)
    if class_names is None:
        # Resolve the legacy global lazily instead of at definition time.
        class_names = globals().get("class_names")
    if class_names is None:
        class_names = [str(k) for k in range(cm.shape[0])]

    # Draw on this figure's own Axes rather than pyplot's implicit "current" one.
    figure, ax = plt.subplots(figsize=(8, 8))
    shading = ax.imshow(cm, interpolation="nearest", cmap=plt.cm.Blues)
    ax.set_title("Confusion matrix")
    figure.colorbar(shading, ax=ax)
    tick_marks = np.arange(len(class_names))
    ax.set_xticks(tick_marks)
    ax.set_xticklabels(class_names, rotation=45)
    ax.set_yticks(tick_marks)
    ax.set_yticklabels(class_names)

    # Compute the cell labels from the row-normalized confusion matrix. A class
    # with no true samples has an all-zero row; label it 0 rather than dividing
    # by zero, which would produce NaN labels and a RuntimeWarning.
    row_sums = cm.sum(axis=1, keepdims=True)
    normalized = np.divide(
        cm.astype(float),
        row_sums,
        out=np.zeros(cm.shape, dtype=float),
        where=row_sums != 0,
    )
    labels = np.around(normalized, decimals=2)

    # Use white text if squares are dark; otherwise black. Shading follows the
    # raw counts (imshow above), so the threshold is on raw counts too.
    threshold = cm.max() / 2.0
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        color = "white" if cm[i, j] > threshold else "black"
        ax.text(
            j,
            i,
            labels[i, j],
            horizontalalignment="center",
            verticalalignment="center",
            color=color,
        )

    ax.set_ylabel("True label")
    ax.set_xlabel("Predicted label")
    # Lay out last so the axis labels set above are accounted for in the margins.
    figure.tight_layout()
    return figure
def plot_to_image(figure):
    """Convert a matplotlib figure to a PNG-decoded image tensor.

    The supplied figure is rendered to an in-memory PNG and then closed. It is
    closed even if rendering fails, so it is inaccessible after this call.

    Args:
        figure: the matplotlib figure to render.

    Returns:
        A tensor from ``tf.image.decode_png(..., channels=4)`` with a leading
        batch dimension of size 1 added, i.e. shape [1, height, width, 4].
    """
    with io.BytesIO() as buf:
        try:
            # Save *this* figure. plt.savefig would instead save whichever
            # figure pyplot currently considers "current", which may differ.
            figure.savefig(buf, format="png")
        finally:
            # Closing the figure prevents it from being displayed directly
            # inside the notebook and releases pyplot's reference to it.
            plt.close(figure)
        png_bytes = buf.getvalue()
    # Convert PNG bytes to a TF image.
    image = tf.image.decode_png(png_bytes, channels=4)
    # Add the batch dimension.
    return tf.expand_dims(image, 0)
def image_grid(labels, preds, miss_class, class_names, images):
    """Return a figure showing the selected samples in a two-column grid.

    Each subplot is titled with the predicted and the actual class name.

    Args:
        labels: true class index per sample, indexed by the entries of
            ``miss_class``.
        preds: predicted class index per sample, indexed the same way.
        miss_class: sized sequence of sample indices to plot (presumably the
            misclassified ones, given the name; verify against caller).
        class_names: sequence mapping a class index to its display name.
        images: per-sample image data. Each selected entry is reshaped to
            [IMAGE_WIDTH, IMAGE_WIDTH, 3], multiplied by 255 and cast to uint8.
            Assumes float pixels in [0, 1]; TODO confirm.

    Returns:
        The matplotlib figure, left open for the caller to render or close.
    """
    n_cols = 2
    # Fill the grid row-major. Previously it was len(miss_class) rows × 2
    # columns, which left the bottom half of the figure empty.
    n_rows = max(1, (len(miss_class) + n_cols - 1) // n_cols)
    # Create a figure to contain the plot (about 3 inches of height per row).
    figure = plt.figure(figsize=(10, 3 * n_rows))
    for position, idx in enumerate(miss_class, start=1):
        # Start next subplot.
        title = (
            f"Predicted: {class_names[preds[idx]]}, "
            f"Acc label: {class_names[labels[idx]]}"
        )
        plt.subplot(n_rows, n_cols, position, title=title)
        plt.xticks([])
        plt.yticks([])
        plt.grid(False)
        img = tf.cast(
            tf.reshape(images[idx], [IMAGE_WIDTH, IMAGE_WIDTH, 3]) * 255, tf.uint8
        )
        # No cmap: matplotlib ignores it for 3-channel RGB data.
        plt.imshow(img)
    return figure
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment