Skip to content

Instantly share code, notes, and snippets.

@kkleidal
Created November 17, 2017 19:55
Show Gist options
  • Save kkleidal/c88e033193edf92d4027943e49b27d96 to your computer and use it in GitHub Desktop.
Save kkleidal/c88e033193edf92d4027943e49b27d96 to your computer and use it in GitHub Desktop.
Matplotlib Image Summaries in Tensorboard
import tensorflow as tf
import io
import matplotlib.pyplot as plt
import numpy as np
import scipy as scipy # Ensure PIL is also installed: pip install pillow
'''
matplotlib_summary code:
Code for generating a tensorflow image summary of a custom matplotlib plot.
Usage: matplotlib_summary(plotting_function, argument1, argument2, ..., name="summary name")
plotting_function is a function which takes the matplotlib figure as its first argument and the numpy
values of argument1, ..., argumentn as its remaining arguments, and draws the plot on that figure.
matplotlib_summary creates and returns a tensorflow image summary
'''
class MatplotlibSummaryOpFactory:
    """Create TensorFlow image-summary ops from custom matplotlib plotting code.

    Calling an instance wraps ``plt_fn`` in ``tf.py_func``: each time the summary
    op is evaluated, a fresh figure is created and ``plt_fn`` draws on it. The
    figure is then rendered to PNG and decoded into a uint8 image array.
    """

    def __init__(self):
        # Incremented to build unique default summary names.
        self.counter = 0

    @staticmethod
    def _to_uint8(im):
        """Return ``im`` as a uint8 array.

        ``plt.imread`` decodes PNG data to floats in [0, 1]; those values are
        scaled back to 0-255. Arrays that are already uint8 are returned as-is.
        """
        im = np.asarray(im)
        if im.dtype == np.uint8:
            return im
        return np.clip(np.round(im * 255.0), 0, 255).astype(np.uint8)

    def _wrap_pltfn(self, plt_fn):
        """Return a numpy-level function suitable for ``tf.py_func``.

        The returned function calls ``plt_fn(figure, *args)`` on a new figure
        and returns the rendered figure as a uint8 image array.
        """
        def plot(*args):
            f = plt.figure()
            try:
                plt_fn(f, *args)
                with io.BytesIO() as buf:
                    # Save this specific figure rather than pyplot's "current"
                    # figure, which plt_fn may have changed.
                    f.savefig(buf, format='png')
                    buf.seek(0)
                    # scipy.misc.imread was removed in SciPy 1.2; matplotlib's
                    # own reader needs no extra dependency.
                    im = plt.imread(buf, format='png')
            finally:
                # A new figure is opened on every evaluation; close it so pyplot
                # does not keep accumulating figures (and their memory).
                plt.close(f)
            return self._to_uint8(im)
        return plot

    def __call__(self, plt_fn, *args, name=None):
        """Return a ``tf.summary.image`` op that renders ``plt_fn``.

        Args:
            plt_fn: callable taking the matplotlib figure followed by the numpy
                values of ``args``; it draws the plot on that figure.
            *args: tensors whose evaluated values are passed to ``plt_fn``.
            name: summary name; defaults to ``"matplotlib-summary_<n>"`` with
                ``n`` taken from an incrementing per-factory counter.
        """
        if name is None:
            self.counter += 1
            name = "matplotlib-summary_%d" % self.counter
        image_tensor = tf.py_func(self._wrap_pltfn(plt_fn), list(args), tf.uint8)
        # The PNG written by savefig is expected to be RGBA, hence 4 channels.
        image_tensor.set_shape([None, None, 4])
        # tf.summary.image takes a batched [batch, height, width, channels] tensor.
        return tf.summary.image(name, tf.expand_dims(image_tensor, 0))


matplotlib_summary = MatplotlibSummaryOpFactory()
'''
END matplotlib_summary code
'''
# Example usage:
def plt_mnist(f, digit):
    """Draw a single-channel image onto the matplotlib figure ``f``.

    ``f`` is the figure supplied by ``matplotlib_summary``; ``digit`` is the
    numpy value of the tensor passed to ``matplotlib_summary``. Its trailing
    axis (size 1) is squeezed away before the image is shown.
    """
    axes = f.gca()
    image = np.squeeze(digit, -1)
    axes.imshow(image)
    axes.set_title("A random MNIST digit")
# Build the example graph: a random 28x28x1 tensor rendered through plt_mnist.
digit = tf.random_normal([28, 28, 1])
summary = matplotlib_summary(plt_mnist, digit, name="mnist-summary")
all_summaries = tf.summary.merge_all()
summary_writer = tf.summary.FileWriter(".")
try:
    with tf.Session() as sess:
        summ = sess.run(all_summaries)
        summary_writer.add_summary(summ, global_step=0)
finally:
    # close() flushes pending events to disk; without it the summary may never
    # be written before the process exits.
    summary_writer.close()
@kkleidal
Copy link
Author

kkleidal commented Nov 17, 2017

Resultant tensorboard:
example

This script lets you use custom matplotlib code to make plots that are shown in TensorBoard. This is handy for extending TensorBoard to show data that its built-in visualizations do not represent well.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment