Created
May 22, 2024 03:54
-
-
Save asmekal/36a0a25cbc34d76fe0a28fc1e866bdb5 to your computer and use it in GitHub Desktop.
invert some albumentations transforms
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import matplotlib.pyplot as plt | |
import cv2 | |
import albumentations as A | |
class MaskInverter:
    """Map an augmented mask back onto the frame of the original image.

    The inverter walks the ``replay`` record produced by
    ``albumentations.ReplayCompose`` in reverse and undoes each applied
    transform.  Supported transforms are ``HorizontalFlip``,
    ``VerticalFlip``, ``ShiftScaleRotate``, ``RandomCrop`` and
    ``CenterCrop``; any other applied transform raises
    ``NotImplementedError``.

    Crop inversions assume that the image had ``original_shape`` right
    before the crop, so at most one crop may appear in a pipeline.
    """

    def __init__(self, original_shape):
        """Store the shape of the un-augmented image.

        Args:
            original_shape: Sequence whose first two entries are
                ``(height, width)``; further entries (e.g. channels) are
                ignored.
        """
        self.original_shape = original_shape

    def invert_horizontal_flip(self, mask):
        """Undo a horizontal flip (a flip around the vertical axis is its own inverse)."""
        return cv2.flip(mask, 1)

    def invert_vertical_flip(self, mask):
        """Undo a vertical flip (a flip around the horizontal axis is its own inverse)."""
        return cv2.flip(mask, 0)

    def invert_shift_scale_rotate(self, mask, params):
        """Undo a projective warp described by ``params['matrix']``.

        Args:
            mask: Augmented mask to warp back.
            params: Replay parameters of the transform.  ``params['matrix']``
                is either an object exposing the 3x3 matrix as ``.params``
                or the 3x3 matrix itself.

        Returns:
            The mask warped with the inverse matrix, sized
            ``original_shape[:2]``.

        Raises:
            ValueError: If the matrix is not of shape (3, 3).
        """
        matrix = params['matrix']
        # Accept both a transform object carrying `.params` and a raw array.
        matrix = np.array(getattr(matrix, 'params', matrix))
        if matrix.shape != (3, 3):
            raise ValueError(f"Projective transform matrix must be of shape (3, 3), but got {matrix.shape}")

        inv_matrix = np.linalg.inv(matrix)
        # cv2 expects the output size as (width, height).
        out_size = (self.original_shape[1], self.original_shape[0])
        # Nearest-neighbour interpolation keeps mask values from being blended.
        return cv2.warpPerspective(mask, inv_matrix, out_size, flags=cv2.INTER_NEAREST)

    def _paste_on_canvas(self, mask, x1, y1, x2, y2):
        """Place ``mask`` into a zero canvas of the original height and width.

        The canvas inherits the dtype and any trailing axes of ``mask`` so
        that float or multi-channel masks are restored without loss.
        """
        canvas_shape = tuple(self.original_shape[:2]) + tuple(mask.shape[2:])
        canvas = np.zeros(canvas_shape, dtype=mask.dtype)
        canvas[y1:y2, x1:x2] = mask
        return canvas

    def invert_random_crop(self, mask, params):
        """Undo a ``RandomCrop`` by pasting the crop back at its source location.

        Args:
            mask: Cropped mask.
            params: Replay parameters holding ``'h_start'`` and ``'w_start'``.

        Returns:
            A zero-filled array of the original height and width with
            ``mask`` written into the region the crop was taken from.
        """
        x1, y1, x2, y2 = A.augmentations.crops.functional.get_random_crop_coords(
            height=self.original_shape[0], width=self.original_shape[1],
            crop_height=mask.shape[0], crop_width=mask.shape[1],
            h_start=params['h_start'], w_start=params['w_start'])
        return self._paste_on_canvas(mask, x1, y1, x2, y2)

    def invert_center_crop(self, mask):
        """Undo a ``CenterCrop`` by pasting the crop back into the centre region.

        Returns:
            A zero-filled array of the original height and width with
            ``mask`` written into the centre-crop region.
        """
        x1, y1, x2, y2 = A.augmentations.crops.functional.get_center_crop_coords(
            height=self.original_shape[0], width=self.original_shape[1],
            crop_height=mask.shape[0], crop_width=mask.shape[1])
        return self._paste_on_canvas(mask, x1, y1, x2, y2)

    def invert_transform(self, aug_mask, replay):
        """Undo every applied transform recorded in ``replay``.

        Args:
            aug_mask: Mask as returned by the augmentation pipeline.
            replay: Replay record; ``replay['transforms']`` is a sequence of
                dicts with the keys ``'applied'``, ``'__class_fullname__'``
                and ``'params'``.

        Returns:
            The mask after all applied transforms have been inverted, in
            reverse order of application.  With no applied transforms the
            input object itself is returned.

        Raises:
            AssertionError: If more than one crop was applied.
            NotImplementedError: If an applied transform is not supported.
        """
        inverted_mask = aug_mask
        # Each crop inversion assumes the pre-crop size equals original_shape,
        # which cannot hold for two size-changing transforms at once.  To
        # support several crops/resizes, the expected size before each
        # transform would have to be supplied explicitly.
        had_resize = False

        # The last transform applied must be the first one undone.
        for transform in reversed(replay['transforms']):
            if not transform['applied']:
                continue

            name = transform['__class_fullname__']
            if name in ('RandomCrop', 'CenterCrop'):
                # Raised explicitly (not via `assert`) so that the check
                # survives `python -O`; the exception type is unchanged.
                if had_resize:
                    raise AssertionError(
                        "only one crop per pipeline can be inverted")
                if name == 'RandomCrop':
                    inverted_mask = self.invert_random_crop(inverted_mask, transform['params'])
                else:
                    inverted_mask = self.invert_center_crop(inverted_mask)
                had_resize = True
            elif name == 'ShiftScaleRotate':
                inverted_mask = self.invert_shift_scale_rotate(inverted_mask, transform['params'])
            elif name == 'VerticalFlip':
                inverted_mask = self.invert_vertical_flip(inverted_mask)
            elif name == 'HorizontalFlip':
                inverted_mask = self.invert_horizontal_flip(inverted_mask)
            else:
                raise NotImplementedError(f"Unsupported transform: {transform['__class_fullname__']}")
        return inverted_mask
# Example usage: augment a synthetic image/mask pair and map the mask back.
transform = A.ReplayCompose(
    [
        # A.HorizontalFlip(p=0.5),
        # A.VerticalFlip(p=0.5),
        A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=45, p=1.0),
        A.RandomCrop(width=256, height=256, p=1.0),
        # A.SmallestMaxSize(max_size=512),
        # A.CenterCrop(512, 512),
    ],
    additional_targets={'mask': 'mask'},
)

# To run on real data instead of the synthetic pair below:
# fn_test = "img.png"
# fn_test_mask = "mask.png"
# image = cv2.imread(fn_test)
# mask = cv2.imread(fn_test_mask, cv2.IMREAD_GRAYSCALE)

# Synthetic inputs: an all-white image and a black mask with a filled
# white rectangle.
image = np.full((512, 512, 3), 255, dtype=np.uint8)
mask = np.zeros((512, 512), dtype=np.uint8)
cv2.rectangle(mask, pt1=(150, 150), pt2=(350, 350), color=255, thickness=-1)

# Run the pipeline; the 'replay' entry is what the inverter consumes.
augmented = transform(image=image, mask=mask)
aug_image = augmented['image']
aug_mask = augmented['mask']
replay = augmented['replay']

# Bring the augmented mask back into the original image frame.
inverted_mask = MaskInverter(image.shape).invert_transform(
    aug_mask,
    replay,
)
def augment(image, mask, transform):
    """Augment an image/mask pair and also return the mask mapped back.

    Args:
        image: Input image; its ``.shape`` is handed to ``MaskInverter``.
        mask: Mask passed to ``transform`` alongside ``image``.
        transform: Callable invoked as ``transform(image=..., mask=...)``
            whose result is indexed with ``'image'``, ``'mask'`` and
            ``'replay'`` (presumably an ``A.ReplayCompose`` — confirm
            against callers).

    Returns:
        Tuple ``(aug_image, aug_mask, replay, inverted_mask)``.
    """
    result = transform(image=image, mask=mask)
    aug_image, aug_mask, replay = result['image'], result['mask'], result['replay']

    # Undo the recorded transforms on the augmented mask.
    inverter = MaskInverter(image.shape)
    restored = inverter.invert_transform(aug_mask, replay)
    return aug_image, aug_mask, replay, restored
def plot_results5(image, mask, aug_image, aug_mask, inverted_mask):
    """Show the five arrays side by side in a single 1x5 matplotlib figure.

    The two images are converted with ``cv2.COLOR_BGR2RGB`` before being
    displayed; the three masks are drawn with the ``'gray'`` colormap.
    """
    # (title, array, is_image) in left-to-right panel order.
    panels = (
        ("Original Image", image, True),
        ("Original Mask", mask, False),
        ("Augmented Image", aug_image, True),
        ("Augmented Mask", aug_mask, False),
        ("Inverted Mask", inverted_mask, False),
    )

    plt.figure(figsize=(15, 5))
    for position, (title, data, is_image) in enumerate(panels, start=1):
        plt.subplot(1, 5, position)
        plt.title(title)
        if is_image:
            plt.imshow(cv2.cvtColor(data, cv2.COLOR_BGR2RGB))
        else:
            plt.imshow(data, cmap='gray')
    plt.tight_layout()
    plt.show()
# Display the original, augmented and restored arrays from the example above.
plot_results5(
    image=image,
    mask=mask,
    aug_image=aug_image,
    aug_mask=aug_mask,
    inverted_mask=inverted_mask,
)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment