Skip to content

Instantly share code, notes, and snippets.

@zaccharieramzi
Last active November 23, 2021 16:06
Show Gist options
  • Save zaccharieramzi/f7dd5f0e34691d0a987e1e50b694ac03 to your computer and use it in GitHub Desktop.
Save zaccharieramzi/f7dd5f0e34691d0a987e1e50b694ac03 to your computer and use it in GitHub Desktop.
This Keras callback logs the outputs of an image-to-image model to TensorBoard. Here it is used for a denoiser, but it could equally be adapted for segmentation, super-resolution, etc.
"""Inspired by https://stackoverflow.com/a/49363251/4332585"""
import io
from keras.callbacks import Callback
import numpy as np
from PIL import Image
from skimage.util import img_as_ubyte
import tensorflow as tf
def make_image(tensor):
    """Convert the first image of a numpy batch to an Image summary protobuf.

    Adapted from https://github.com/lanpa/tensorboard-pytorch/

    The image is min-max normalized to [0, 1], converted to ``uint8`` with
    ``img_as_ubyte`` and PNG-encoded.

    Args:
        tensor: array unpacked as ``(batch, height, width, channel)``; only
            ``tensor[0]`` is encoded.

    Returns:
        A ``tf.Summary.Image`` holding the PNG bytes of ``tensor[0]``.
    """
    _, height, width, channel = tensor.shape
    # Work in float: with an integer input the in-place division below would
    # raise, since numpy cannot cast the float result back into an int array.
    image_array = np.asarray(tensor[0], dtype=np.float64)
    image_array = image_array - image_array.min()
    value_range = image_array.max()
    # A constant image has zero range; dividing by it would produce NaNs, so
    # it is left as all zeros instead.
    if value_range > 0:
        image_array /= value_range
    image_uint8 = img_as_ubyte(image_array)
    # Drop only the channel axis of single-channel images so PIL receives a
    # 2-D grayscale array; np.squeeze would also drop height/width axes of
    # size 1 and hand PIL a mis-shaped array.
    if channel == 1:
        image_uint8 = image_uint8[..., 0]
    pil_image = Image.fromarray(image_uint8)
    with io.BytesIO() as output:
        pil_image.save(output, format='PNG')
        image_string = output.getvalue()
    return tf.Summary.Image(
        height=height,
        width=width,
        colorspace=channel,
        encoded_image_string=image_string,
    )
class TensorBoardImage(Callback):
    """Keras callback that logs an image-to-image model's output to TensorBoard.

    When training starts, the reference ``image`` is written under the tag
    ``'Original Image'`` at step 0. After every epoch the model is run on
    ``noisy_image`` and its prediction is written under ``'Denoised Image'``
    at that epoch index.

    Args:
        log_dir: directory in which the summary event file is written.
        image: batch whose first element is the reference image
            (encoded by ``make_image``).
        noisy_image: batch passed to ``model.predict_on_batch`` each epoch.
    """

    def __init__(self, log_dir, image, noisy_image):
        super().__init__()
        self.log_dir = log_dir
        self.image = image
        self.noisy_image = noisy_image

    def set_model(self, model):
        """Store the model and open a summary writer on ``log_dir``."""
        self.model = model
        # NOTE(review): the suffix presumably lets this event file be told
        # apart from other writers sharing log_dir — confirm intent.
        self.writer = tf.summary.FileWriter(self.log_dir, filename_suffix='images')

    def on_train_begin(self, logs=None):
        """Log the reference image once, at step 0."""
        self.write_image(self.image, 'Original Image', 0)

    def on_train_end(self, logs=None):
        """Close the summary writer opened in ``set_model``."""
        self.writer.close()

    def write_image(self, image, tag, epoch):
        """Encode the first image of ``image`` and write it under ``tag``.

        Args:
            image: batch whose first element is logged.
            tag: TensorBoard tag for the image.
            epoch: global step the summary is recorded at.
        """
        summary = tf.Summary(
            value=[tf.Summary.Value(tag=tag, image=make_image(image))]
        )
        self.writer.add_summary(summary, epoch)
        # Flush right away so the image is on disk while training continues.
        self.writer.flush()

    def on_epoch_end(self, epoch, logs=None):
        """Run the model on the noisy input and log its prediction."""
        denoised_image = self.model.predict_on_batch(self.noisy_image)
        self.write_image(denoised_image, 'Denoised Image', epoch)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment