Skip to content

Instantly share code, notes, and snippets.

@RaphaelMeudec
Last active October 25, 2020 10:12
Show Gist options
  • Save RaphaelMeudec/74d7889e0dea467b0d8107c64792ce8d to your computer and use it in GitHub Desktop.
Create a simple tf.data Dataset for an image deblurring task
from pathlib import Path
import tensorflow as tf
def select_patch(sharp, blur, patch_size_x, patch_size_y):
    """
    Crop one random window at the same location from a sharp image and its blurred twin.

    The two images are stacked along a new leading axis, so a single
    ``tf.image.random_crop`` call draws one offset shared by both images.

    Args:
        sharp (tf.Tensor): Sharp image; must have the same shape as ``blur``
            (required by ``tf.stack``).
        blur (tf.Tensor): Blurred image.
        patch_size_x (int): Patch extent along axis 0 of each image.
        patch_size_y (int): Patch extent along axis 1 of each image.

    Returns:
        Tuple[tf.Tensor, tf.Tensor]: ``(sharp_patch, blur_patch)``, each of shape
        ``(patch_size_x, patch_size_y, 3)``.
    """
    # Leading 2 keeps both images in the crop; trailing 3 keeps every channel.
    crop_shape = [2, patch_size_x, patch_size_y, 3]
    pair = tf.stack([sharp, blur], axis=0)
    cropped_pair = tf.image.random_crop(pair, size=crop_shape)
    sharp_patch, blur_patch = tf.unstack(cropped_pair, num=2, axis=0)
    return (sharp_patch, blur_patch)
class TensorflowDatasetLoader:
    """
    Build a ``tf.data`` pipeline of batched (sharp, blur) patch pairs for deblurring.

    Sharp images are found with the glob ``<dataset_path>/*/sharp/*.png``. Each
    one is paired with the file of the same name in the sibling ``blur``
    directory: ``<dataset_path>/<subdir>/blur/<name>.png``. The finished
    pipeline is stored on ``self.dataset``.
    """

    def __init__(
        self,
        dataset_path,
        batch_size=4,
        patch_size=(256, 256),
        n_epochs=10,
        n_images=None,
        dtype=tf.float32,
    ):
        """
        Args:
            dataset_path (str | pathlib.Path): Root directory that holds the
                per-subdirectory ``sharp``/``blur`` folders.
            batch_size (int): Number of patch pairs per batch.
            patch_size (Tuple[int, int]): Patch extent along axis 0 and axis 1
                of each image.
            n_epochs (int | None): Number of passes over the data.
                ``None`` repeats indefinitely.
            n_images (int | None): If given, keep only the first ``n_images``
                sharp images, in sorted path order.
            dtype (tf.DType): Dtype passed to ``tf.image.convert_image_dtype``
                before rescaling. Presumably this should be a floating type
                for the [-1, 1] rescale to be meaningful.

        Raises:
            ValueError: If no sharp images are selected.
        """
        # Sort so the selection (and the ``n_images`` subset) does not depend
        # on filesystem listing order.
        sharp_images_paths = sorted(Path(dataset_path).glob("*/sharp/*.png"))
        if n_images is not None:
            sharp_images_paths = sharp_images_paths[:n_images]
        if not sharp_images_paths:
            raise ValueError(
                "No sharp images selected under {} (pattern '*/sharp/*.png', "
                "n_images={})".format(dataset_path, n_images)
            )

        # Build the blurred path structurally: same file name in the sibling
        # "blur" directory. str.replace would also rewrite any "sharp" that
        # appears elsewhere in the path.
        blur_images_paths = [
            path.parent.parent / "blur" / path.name for path in sharp_images_paths
        ]

        dataset = tf.data.Dataset.from_tensor_slices(
            (
                [str(path) for path in sharp_images_paths],
                [str(path) for path in blur_images_paths],
            )
        )
        # Shuffle individual pairs while they are still cheap path strings,
        # before decoding and batching, so batch composition changes between
        # epochs. Reshuffling each iteration is the tf.data default.
        dataset = dataset.shuffle(buffer_size=len(sharp_images_paths))
        # Load both images of each pair.
        dataset = dataset.map(
            lambda sharp_path, blur_path: (
                self.load_image(sharp_path, dtype),
                self.load_image(blur_path, dtype),
            )
        )
        # Select the same patch on the sharp image and its corresponding blurred one.
        dataset = dataset.map(
            lambda sharp_image, blur_image: select_patch(
                sharp_image, blur_image, patch_size[0], patch_size[1]
            )
        )
        dataset = dataset.batch(batch_size)
        # repeat(None) repeats forever; an int gives exactly that many epochs.
        dataset = dataset.repeat(n_epochs)
        self.dataset = dataset

    @staticmethod
    def load_image(image_path, dtype):
        """
        Read a PNG file and return it as a 3-channel tensor rescaled to [-1, 1].

        Args:
            image_path (tf.Tensor): Scalar string tensor holding the file path.
            dtype (tf.DType): Target dtype for ``tf.image.convert_image_dtype``.
                That call scales integer pixels into [0, 1] when ``dtype`` is
                floating, which the ``(x - 0.5) * 2`` step below assumes.

        Returns:
            tf.Tensor: Decoded image with 3 channels (``channels=3``).
        """
        image = tf.io.read_file(image_path)
        image = tf.image.decode_png(image, channels=3)
        image = tf.image.convert_image_dtype(image, dtype)
        # Map [0, 1] -> [-1, 1].
        image = (image - 0.5) * 2
        return image
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment