Skip to content

Instantly share code, notes, and snippets.

@asmekal
Created May 22, 2024 03:54
Show Gist options
  • Save asmekal/36a0a25cbc34d76fe0a28fc1e866bdb5 to your computer and use it in GitHub Desktop.
Save asmekal/36a0a25cbc34d76fe0a28fc1e866bdb5 to your computer and use it in GitHub Desktop.
invert some albumentations transforms
import numpy as np
import matplotlib.pyplot as plt
import cv2
import albumentations as A
class MaskInverter:
    """Map an augmented mask back into the coordinate frame of the original image.

    The inverter consumes the ``replay`` record produced by
    ``albumentations.ReplayCompose`` and undoes the applied transforms in
    reverse order. Supported transforms: ``HorizontalFlip``, ``VerticalFlip``,
    ``ShiftScaleRotate``, ``RandomCrop`` and ``CenterCrop``.

    Limitation: crop inversion assumes the crop was taken from an image of
    ``original_shape``, so at most one crop per pipeline can be undone.
    """

    def __init__(self, original_shape):
        """Remember the shape of the un-augmented image.

        Args:
            original_shape: ``(height, width, ...)`` of the original image;
                only the first two entries are used.
        """
        self.original_shape = original_shape

    def invert_horizontal_flip(self, mask: np.ndarray) -> np.ndarray:
        """Undo a horizontal flip (a flip is its own inverse)."""
        return cv2.flip(mask, 1)

    def invert_vertical_flip(self, mask: np.ndarray) -> np.ndarray:
        """Undo a vertical flip (a flip is its own inverse)."""
        return cv2.flip(mask, 0)

    def invert_shift_scale_rotate(self, mask: np.ndarray, params: dict) -> np.ndarray:
        """Undo a ``ShiftScaleRotate`` by warping with the inverse matrix.

        Args:
            mask: Mask as it looked right after the transform.
            params: Replay parameters of the transform. ``params['matrix']``
                must expose a ``.params`` attribute holding a 3x3 matrix
                (presumably a projective-transform object stored by
                albumentations; confirm for the installed version).

        Returns:
            The mask warped back, with the same height and width as ``mask``.

        Raises:
            ValueError: If the recorded matrix is not 3x3.
        """
        matrix = np.array(params['matrix'].params)
        if matrix.shape != (3, 3):
            raise ValueError(f"Projective transform matrix must be of shape (3, 3), but got {matrix.shape}")
        inv_matrix = np.linalg.inv(matrix)
        # ShiftScaleRotate does not change the image size, so the inverse is
        # rendered at the incoming mask's own size. This equals original_shape
        # whenever crops were already undone, and stays correct when a crop
        # preceded this transform in the forward pipeline.
        height, width = mask.shape[:2]
        # Nearest-neighbour interpolation keeps label values intact.
        return cv2.warpPerspective(mask, inv_matrix, (width, height), flags=cv2.INTER_NEAREST)

    def _paste_crop(self, mask: np.ndarray, x1: int, y1: int, x2: int, y2: int) -> np.ndarray:
        """Place ``mask`` into a zero canvas of the original height and width.

        The canvas keeps the dtype and any trailing channel axes of ``mask`` so
        that float, wide-integer and multi-channel masks survive unchanged.
        """
        canvas_shape = tuple(self.original_shape[:2]) + tuple(mask.shape[2:])
        canvas = np.zeros(canvas_shape, dtype=mask.dtype)
        canvas[y1:y2, x1:x2] = mask
        return canvas

    def invert_random_crop(self, mask: np.ndarray, params: dict) -> np.ndarray:
        """Undo a ``RandomCrop`` by pasting the crop back at its recorded offset.

        Args:
            mask: The cropped mask.
            params: Replay parameters; ``'h_start'`` and ``'w_start'`` are read.

        Returns:
            A zero-padded mask with the original height and width.
        """
        x1, y1, x2, y2 = A.augmentations.crops.functional.get_random_crop_coords(
            height=self.original_shape[0], width=self.original_shape[1],
            crop_height=mask.shape[0], crop_width=mask.shape[1],
            h_start=params['h_start'], w_start=params['w_start'])
        return self._paste_crop(mask, x1, y1, x2, y2)

    def invert_center_crop(self, mask: np.ndarray) -> np.ndarray:
        """Undo a ``CenterCrop`` by pasting the crop back into the centre.

        Returns:
            A zero-padded mask with the original height and width.
        """
        x1, y1, x2, y2 = A.augmentations.crops.functional.get_center_crop_coords(
            height=self.original_shape[0], width=self.original_shape[1],
            crop_height=mask.shape[0], crop_width=mask.shape[1])
        return self._paste_crop(mask, x1, y1, x2, y2)

    def invert_transform(self, aug_mask: np.ndarray, replay: dict) -> np.ndarray:
        """Undo every applied transform recorded in ``replay``.

        Args:
            aug_mask: Mask produced by the augmentation pipeline.
            replay: Replay record; ``replay['transforms']`` is a sequence of
                dicts with the keys ``'applied'``, ``'__class_fullname__'``
                and ``'params'``.

        Returns:
            The mask after undoing the applied transforms, last one first.
            ``aug_mask`` itself is returned when nothing was applied.

        Raises:
            AssertionError: If more than one crop was applied; the inverse of a
                crop assumes it was taken from an image of ``original_shape``.
            NotImplementedError: If an applied transform is not supported.
        """
        inverted_mask = aug_mask
        crop_undone = False
        for transform in reversed(replay['transforms']):
            if not transform['applied']:
                continue
            full_name = transform['__class_fullname__']
            # Accept both short names ("RandomCrop") and module-qualified ones
            # ("albumentations....RandomCrop").
            class_name = full_name.rsplit('.', 1)[-1]
            if class_name in ('RandomCrop', 'CenterCrop'):
                # Raised explicitly instead of via ``assert`` so the check
                # survives ``python -O``.
                if crop_undone:
                    raise AssertionError(
                        "Only one crop per pipeline can be inverted: crop inversion "
                        "assumes the crop was taken from the original image size")
                if class_name == 'RandomCrop':
                    inverted_mask = self.invert_random_crop(inverted_mask, transform['params'])
                else:
                    inverted_mask = self.invert_center_crop(inverted_mask)
                crop_undone = True
            elif class_name == 'ShiftScaleRotate':
                inverted_mask = self.invert_shift_scale_rotate(inverted_mask, transform['params'])
            elif class_name == 'VerticalFlip':
                inverted_mask = self.invert_vertical_flip(inverted_mask)
            elif class_name == 'HorizontalFlip':
                inverted_mask = self.invert_horizontal_flip(inverted_mask)
            else:
                raise NotImplementedError(f"Unsupported transform: {full_name}")
        return inverted_mask
# Example usage: augment a synthetic image/mask pair, then map the augmented
# mask back onto the original image grid.
#
# HorizontalFlip, VerticalFlip and CenterCrop can be added to the pipeline as
# well; only one crop per pipeline can be inverted.
pipeline_steps = [
    A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=45, p=1.0),
    A.RandomCrop(width=256, height=256, p=1.0),
]
transform = A.ReplayCompose(pipeline_steps, additional_targets={'mask': 'mask'})

# Synthetic test data: an all-white 512x512 BGR image and a black mask with a
# filled white square. To try real data, load the pair with cv2.imread instead
# (use cv2.IMREAD_GRAYSCALE for the mask).
image = np.full((512, 512, 3), 255, dtype=np.uint8)
mask = np.zeros((512, 512), dtype=np.uint8)
cv2.rectangle(mask, (150, 150), (350, 350), 255, thickness=cv2.FILLED)

# Run the pipeline; ReplayCompose records what was applied under 'replay'.
augmented = transform(image=image, mask=mask)
aug_image, aug_mask, replay = (augmented[key] for key in ('image', 'mask', 'replay'))

# Undo the recorded transforms on the augmented mask.
inverter = MaskInverter(image.shape)
inverted_mask = inverter.invert_transform(aug_mask, replay)
def augment(image, mask, transform):
    """Augment an image/mask pair and also map the augmented mask back.

    Args:
        image: Source image; its ``shape`` defines the frame to restore to.
        mask: Mask aligned with ``image``.
        transform: Callable accepting ``image=`` and ``mask=`` keywords and
            returning a mapping with the keys ``'image'``, ``'mask'`` and
            ``'replay'``.

    Returns:
        Tuple ``(aug_image, aug_mask, replay, inverted_mask)``.
    """
    result = transform(image=image, mask=mask)
    warped_image, warped_mask, record = (
        result[key] for key in ('image', 'mask', 'replay')
    )
    # Undo the recorded transforms so the mask lines up with the input again.
    restored_mask = MaskInverter(image.shape).invert_transform(warped_mask, record)
    return warped_image, warped_mask, record, restored_mask
def plot_results5(image, mask, aug_image, aug_mask, inverted_mask):
    """Show the original, augmented and inverted data as five panels in a row.

    Images are converted from BGR to RGB before display; masks are drawn with
    a gray colormap.
    """
    # (title, data, True when the data is a BGR image rather than a mask)
    panels = (
        ("Original Image", image, True),
        ("Original Mask", mask, False),
        ("Augmented Image", aug_image, True),
        ("Augmented Mask", aug_mask, False),
        ("Inverted Mask", inverted_mask, False),
    )

    plt.figure(figsize=(15, 5))
    for position, (title, data, is_bgr_image) in enumerate(panels, start=1):
        plt.subplot(1, len(panels), position)
        plt.title(title)
        if is_bgr_image:
            plt.imshow(cv2.cvtColor(data, cv2.COLOR_BGR2RGB))
        else:
            plt.imshow(data, cmap='gray')
    plt.tight_layout()
    plt.show()
# Visual check: the inverted mask should line up with the original mask
# wherever the augmented crop covered it.
plot_results5(
    image=image,
    mask=mask,
    aug_image=aug_image,
    aug_mask=aug_mask,
    inverted_mask=inverted_mask,
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment