@dnkirill
Created July 24, 2017 12:16
# Assumption: `shuffle` is sklearn.utils.shuffle, which permutes X and y in unison.
# `affine_transform` and `flips_rotations_augmentation` are defined elsewhere.
from sklearn.utils import shuffle


def augmented_batch_generator(X, y, batch_size, rotations=True, affine=True,
                              shear_angle=0.0, scale_margins=[0.8, 1.5], p=0.35):
    """Augmented batch generator. Shuffles the dataset, splits it into batches,
    and augments each batch independently.

    Args:
        X: numpy array with images.
        y: list of labels.
        batch_size: the size of the output batch before augmentation.
        rotations: whether to apply `flips_rotations_augmentation` function to each batch.
        affine: whether to apply `affine_transform` function to each batch.
        shear_angle: `shear_angle` argument for `affine_transform` function.
        scale_margins: `scale_margins` argument for `affine_transform` function.
        p: `p` argument for `affine_transform` function.

    Yields:
        A `(batch_x, batch_y)` tuple for every batch in the shuffled dataset.
    """
    X_aug, y_aug = shuffle(X, y)
    # Batch generation: the last batch is smaller when the dataset size
    # is not a multiple of batch_size.
    for offset in range(0, X_aug.shape[0], batch_size):
        end = offset + batch_size
        batch_x, batch_y = X_aug[offset:end, ...], y_aug[offset:end]
        # Batch augmentation
        if affine:
            batch_x = affine_transform(batch_x, shear_angle=shear_angle,
                                       scale_margins=scale_margins, p=p)
        if rotations:
            batch_x, batch_y = flips_rotations_augmentation(batch_x, batch_y)
        yield batch_x, batch_y
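
A minimal usage sketch of the generator follows. The gist does not include `affine_transform`, `flips_rotations_augmentation`, or the import for `shuffle`, so the versions below are hypothetical stand-ins written only to make the example runnable: the affine stand-in returns the batch unchanged, the flip stand-in appends a mirrored copy of every image (which doubles the batch), and `shuffle` is a small NumPy replacement for what is assumed to be `sklearn.utils.shuffle`. The real helpers may behave differently.

```python
import numpy as np


def shuffle(X, y, seed=0):
    """Stand-in for sklearn.utils.shuffle (assumed): permute X and y in unison."""
    idx = np.random.default_rng(seed).permutation(len(X))
    return X[idx], np.asarray(y)[idx]


def affine_transform(batch_x, shear_angle=0.0, scale_margins=[0.8, 1.5], p=0.35):
    """Hypothetical stand-in: returns the batch unchanged."""
    return batch_x


def flips_rotations_augmentation(batch_x, batch_y):
    """Hypothetical stand-in: appends a horizontally mirrored copy of each image."""
    flipped = batch_x[:, :, ::-1, ...]
    return np.concatenate([batch_x, flipped]), np.concatenate([batch_y, batch_y])


def augmented_batch_generator(X, y, batch_size, rotations=True, affine=True,
                              shear_angle=0.0, scale_margins=[0.8, 1.5], p=0.35):
    """Compact copy of the generator so that this sketch runs on its own."""
    X_aug, y_aug = shuffle(X, y)
    for offset in range(0, X_aug.shape[0], batch_size):
        end = offset + batch_size
        batch_x, batch_y = X_aug[offset:end, ...], y_aug[offset:end]
        if affine:
            batch_x = affine_transform(batch_x, shear_angle=shear_angle,
                                       scale_margins=scale_margins, p=p)
        if rotations:
            batch_x, batch_y = flips_rotations_augmentation(batch_x, batch_y)
        yield batch_x, batch_y


# Ten dummy 32x32 RGB images with labels 0..9.
X = np.zeros((10, 32, 32, 3), dtype=np.float32)
y = np.arange(10)

# Collect the shape of every yielded batch.
batch_shapes = [(bx.shape, by.shape)
                for bx, by in augmented_batch_generator(X, y, batch_size=4)]

# Collect every yielded label to confirm that no sample is dropped.
all_labels = np.concatenate(
    [by for _, by in augmented_batch_generator(X, y, batch_size=4)])
```

With ten samples and `batch_size=4`, the generator yields three batches of 4, 4, and 2 samples. Because the flip stand-in doubles each batch, the yielded batches hold 8, 8, and 4 images. A training loop should therefore not assume that every batch has exactly `batch_size` items.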