Skip to content

Instantly share code, notes, and snippets.

@farukcankaya
Created October 20, 2022 16:45
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save farukcankaya/dc8812bb685aa33a6790eee362d78478 to your computer and use it in GitHub Desktop.
Save farukcankaya/dc8812bb685aa33a6790eee362d78478 to your computer and use it in GitHub Desktop.
class InstanceColorJitterTransform(MultiModalTransform):
    """Apply a color operation to a random subset of annotated instances.

    For each selected annotation, the instance's pixels are isolated with its
    segmentation bitmask, passed through ``color_operation``, and blended back
    into the image; pixels outside the mask are left as they were.
    """

    def __init__(self, color_operation: Callable, instance_rate: float, min_count_to_apply: int) -> None:
        """Validate and store the transform configuration.

        Args:
            color_operation: Callable invoked with the masked image (built via
                ``Image.fromarray``) and returning an object that
                ``np.asarray`` can convert back to an array.
            instance_rate: Fraction of the instances in an image to augment.
            min_count_to_apply: Lower bound on the number of instances to
                augment (capped at the number of instances available).

        Raises:
            ValueError: If ``color_operation`` is not callable.
        """
        if not callable(color_operation):
            raise ValueError("color_operation parameter should be callable")
        super().__init__()
        # Passes the constructor arguments to the base-class helper; presumably
        # it stores them as attributes (self.color_operation, ...) since
        # apply_multi_modal reads them back — confirm against the base class.
        self._set_attributes(locals())

    def apply_multi_modal(self, img, annos, *args):
        """Color-augment randomly chosen instances of ``img``.

        Args:
            img: Image array; ``img.shape[:2]`` is used as the mask size.
            annos: Sequence of annotation dicts, each with a
                ``"segmentation"`` entry.
            *args: Ignored.

        Returns:
            Tuple ``(img, annos)``: the augmented image and the annotations,
            which are returned unmodified.
        """
        instance_count = len(annos)
        requested = max(self.min_count_to_apply, int(instance_count * self.instance_rate))
        # random.sample raises ValueError when asked for more items than the
        # population holds, so never request more instances than exist.
        apply_count = min(instance_count, requested)

        for idx in random.sample(range(instance_count), apply_count):
            segm = annos[idx]["segmentation"]
            # NOTE(review): assumes _get_bitmask returns a 0/1 mask that
            # broadcasts against img — confirm against the base class.
            bitmask = MultiModalTransform._get_bitmask(segm, img.shape[:2])
            augmented_instance = self.color_operation(Image.fromarray(img * bitmask))
            blended = img * (1 - bitmask) + np.asarray(augmented_instance) * bitmask
            # The arithmetic above can promote the dtype (e.g. with a boolean
            # mask); cast back so the next iteration's Image.fromarray call and
            # the caller both see the original image dtype.
            img = blended.astype(img.dtype, copy=False)
        return img, annos
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment