Skip to content

Instantly share code, notes, and snippets.

@RaphaelMeudec
Created January 14, 2020 14:56
Show Gist options
  • Save RaphaelMeudec/15e940c8645e2a92d49502f126b9b182 to your computer and use it in GitHub Desktop.
Save RaphaelMeudec/15e940c8645e2a92d49502f126b9b182 to your computer and use it in GitHub Desktop.
How to create a k-way n-shot tf.data.Dataset
def build_k_way_n_shot_dataset(annotations, n_shot, k_way, classes=None, to_categorical=True, training=True):
    """Build an infinite dataset whose batches group ``n_shot`` crops per class.

    The stream is made of runs of ``n_shot`` consecutive samples of one class,
    the class order being a reshuffled permutation of the available classes
    that is repeated forever. The stream is then cut into batches of
    ``n_shot * k_way`` elements, i.e. ``k_way`` consecutive runs per batch.
    Note that when the number of available classes is not a multiple of
    ``k_way``, a batch can span the end of one permutation and the start of
    the next one.

    Args:
        annotations: pandas DataFrame with the columns ``"label"``,
            ``"image_name"`` (path given to ``tf.io.read_file``), ``"x1"``,
            ``"y1"``, ``"x2"`` and ``"y2"`` (crop box corners, in pixels).
        n_shot: number of consecutive samples drawn from a class.
        k_way: number of class runs per batch.
        classes: optional ordered list of label values; the position of a label
            in this list is its class code. When ``None``, pandas infers the
            categories from the labels found in ``annotations``.
        to_categorical: when True, targets are float32 one-hot vectors with one
            entry per category; when False, targets are int32 class codes.
        training: when True, decoded crops are cached and random flips are
            applied on each pass.

    Returns:
        A ``tf.data.Dataset`` yielding ``(images, targets)`` batches of size
        ``n_shot * k_way``; images are 224x224 float32 crops scaled to [-1, 1].

    Raises:
        ValueError: if no row of ``annotations`` belongs to a known class.
    """
    # Encode labels as integer codes. Labels missing or absent from `classes` get code -1.
    annotations = annotations.assign(label=pd.Categorical(annotations.label, categories=classes))
    codes = annotations.label.cat.codes
    num_categories = len(annotations.label.cat.categories)

    if to_categorical:
        # One column per category (not per *observed* code): classes without samples keep
        # an all-zero column and the -1 column is dropped, so vector positions always
        # match the class codes.
        targets = (
            pd.get_dummies(codes)
            .reindex(columns=list(range(num_categories)), fill_value=0)
            .values.astype("float32")
        )
    else:
        targets = codes.values.astype("int32")

    # Only classes that actually have samples can feed a per-class dataset.
    class_ids = sorted(int(code) for code in codes.unique() if code >= 0)
    if not class_ids:
        raise ValueError("annotations do not contain any sample of the requested classes")

    batch_size = n_shot * k_way

    def load_image_and_crop(image_path, x1, y1, x2, y2):
        """Load a PNG image, scale it to [-1, 1] and crop it to the given box."""
        image = tf.io.read_file(image_path)
        image = tf.image.decode_png(image)
        image = (tf.image.convert_image_dtype(image, tf.float32) - 0.5) * 2
        image = tf.image.crop_to_bounding_box(image, y1, x1, y2 - y1, x2 - x1)
        # Center-crops or zero-pads to 224x224; the content is not rescaled.
        image = tf.image.resize_with_crop_or_pad(image, 224, 224)
        return image

    def data_aug(image):
        """Apply random horizontal and vertical flips."""
        image = tf.image.random_flip_left_right(image)
        image = tf.image.random_flip_up_down(image)
        return image

    def build_dataset_for_class(index_class):
        """Build the (image, target) dataset restricted to one class code."""
        print(f"Building for {index_class}")
        # Positional boolean mask: stays correct even if the dataframe index has duplicates.
        mask = (codes == index_class).values
        class_annotations = annotations[mask]
        class_targets = targets[mask]

        dataset = tf.data.Dataset.from_tensor_slices((
            class_annotations["image_name"],
            class_annotations["x1"],
            class_annotations["y1"],
            class_annotations["x2"],
            class_annotations["y2"],
            class_targets,
        ))
        dataset = dataset.map(
            lambda image_name, x1, y1, x2, y2, target: (load_image_and_crop(image_name, x1, y1, x2, y2), target),
            num_parallel_calls=tf.data.experimental.AUTOTUNE,
        )
        if training:
            # Cache the decoded crops before augmentation so flips stay random on every pass.
            dataset = dataset.cache()
            dataset = dataset.map(
                lambda image, target: (data_aug(image), target),
                num_parallel_calls=tf.data.experimental.AUTOTUNE,
            )
        return dataset

    # One endlessly repeated dataset per class that has samples.
    datasets_by_class = [build_dataset_for_class(index_class).repeat() for index_class in class_ids]
    num_datasets = len(datasets_by_class)

    # Create the choice dataset that defines in which dataset we want to pick elements.
    # Values are positions in `datasets_by_class`. Typically, if n_shot=4, it produces:
    # [1, 1, 1, 1, 12, 12, 12, 12, 37, 37, 37, 37, 25, 25, 25, 25, ...]
    choice_dataset = tf.data.Dataset.range(num_datasets).shuffle(buffer_size=num_datasets).repeat().interleave(
        lambda index: tf.data.Dataset.from_tensors(index).repeat(n_shot),
        cycle_length=1,
        block_length=n_shot,
    )

    # choose_from_datasets picks each element from the dataset designated by choice_dataset.
    dataset = tf.data.experimental.choose_from_datasets(datasets_by_class, choice_dataset).batch(batch_size)
    return dataset
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment