Skip to content

Instantly share code, notes, and snippets.

@morganmcg1
Created February 21, 2023 13:10
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save morganmcg1/c4dcdc9dc7130a00096739830bf69d5d to your computer and use it in GitHub Desktop.
Save morganmcg1/c4dcdc9dc7130a00096739830bf69d5d to your computer and use it in GitHub Desktop.
# Requires: datasets, albumentations, Pillow (PIL), numpy, torch
import ast

import albumentations as A
import numpy as np
import torch
from albumentations.augmentations.transforms import Normalize
from albumentations.pytorch.transforms import ToTensorV2
from datasets import load_dataset
from datasets.download.download_manager import DownloadMode #, REUSE_DATASET_IF_EXISTS, REUSE_CACHE_IF_EXISTS
from PIL import Image
# Dataset split to load and the Hugging Face Hub repository it lives in.
split = "train"
REFEXP_DATASET_NAME = "ivelin/ui_refexp_saved"

# FORCE_REDOWNLOAD bypasses any cached copy, so every run fetches the data again.
train_ds = load_dataset(
    REFEXP_DATASET_NAME,
    split=split,
    num_proc=8,
    download_mode=DownloadMode.FORCE_REDOWNLOAD,
)

# Peek at one sample image. This is a bare expression: it has no visible
# effect when the file is run as a plain script.
train_ds[800]["image"]
def bbox_preprocess(bbox):
    '''
    Return a bounding box as a Python object, parsing it if it is a string.

    Args:
        bbox: either a string holding a Python literal, e.g.
            "{'xmin': 0.1, 'ymin': 0.2, 'xmax': 0.3, 'ymax': 0.4}",
            or an already-parsed value.

    Returns:
        The result of `ast.literal_eval` for string input; any other input is
        returned unchanged (the very same object).

    Raises:
        ValueError, SyntaxError: if a string input is not a valid Python literal.
    '''
    if isinstance(bbox, str):
        # literal_eval only evaluates literals (no calls, no names), unlike eval().
        return ast.literal_eval(bbox)
    return bbox
# Augmentation pipeline applied to each image together with its bounding boxes.
transform = A.Compose(
    [
        # Rescale so that the longest side is 1024 px
        # (interpolation=1 is OpenCV's INTER_LINEAR).
        A.LongestMaxSize(max_size=1024, interpolation=1, always_apply=True, p=1),
        # A.RandomBrightnessContrast(p=0.3),  # optional augmentation, currently disabled
        # Normalize() is used with its defaults, i.e. the ImageNet statistics:
        #   mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), max_pixel_value=255.0
        Normalize(always_apply=True, p=1.0),
        # HWC numpy array -> CHW torch tensor; the channel order is not changed here.
        ToTensorV2(transpose_mask=False, always_apply=True, p=1.0),
    ],
    # Boxes are given as [xmin, ymin, xmax, ymax] in pixels; their labels are
    # passed separately through the `class_labels` argument.
    bbox_params=A.BboxParams(format="pascal_voc", label_fields=["class_labels"]),
)
def transforms(examples):
    '''
    Batch transform for `Dataset.set_transform`: run every image and its
    target bounding box through the module-level `transform` pipeline.

    Based on: https://huggingface.co/docs/datasets/object_detection

    Args:
        examples: batch dict with an 'image' list of PIL images and a
            'target_bounding_box' list of boxes (dicts, or their string form)
            holding 'xmin', 'ymin', 'xmax' and 'ymax'. The values are
            multiplied by the image width/height, i.e. they are treated as
            fractions of the image size.

    Returns:
        dict with
            'image': list of image tensors exactly as produced by `transform`
                (channel order RGB, same as the PIL input),
            'bbox': list of float16 tensors [xmin, ymin, xmax, ymax] as
                returned by `transform` for the processed image.
    '''
    images, bboxes = [], []
    # `transform` declares `class_labels` as a label field, so one label per
    # box has to be supplied even though the value itself is never used.
    class_labels = ['dummy']
    for image, bbox in zip(examples['image'], examples['target_bounding_box']):
        # Get Bounding Boxes: scale the fractional corners to pixel coordinates.
        width, height = image.size
        bbox = bbox_preprocess(bbox)
        # Only the four coordinates go into the box; the label travels through
        # `class_labels`, so no placeholder string is mixed into the numbers.
        boxes_xyxy = [[
            bbox["xmin"] * width,
            bbox["ymin"] * height,
            bbox["xmax"] * width,
            bbox["ymax"] * height,
        ]]

        # Transform Image.
        # The array is kept in RGB order: Normalize's default mean/std are
        # listed in RGB order, so flipping to BGR first would normalise the
        # red and blue channels with each other's statistics.
        image = np.array(image.convert("RGB"))
        print(f"pil out: {image.shape}")

        # print image stats before transform
        print(image.shape)
        print(image.max())
        print(image[:, :, 0].mean())
        print(image[:, :, 1].mean())
        print(image[:, :, 2].mean())

        # Takes a numpy array and returns a PyTorch tensor (via ToTensorV2 transform)
        out = transform(
            image=image,
            bboxes=boxes_xyxy,
            class_labels=class_labels,
        )

        # print image per channel stats after transform
        print(out['image'].shape)
        print(out['image'][0, :, :].mean())
        print(out['image'][1, :, :].mean())
        print(out['image'][2, :, :].mean())

        # ToTensorV2 already returns a tensor, so it is appended as is
        # (wrapping it in torch.tensor() again only copies it and warns).
        images.append(out['image'])
        # NOTE(review): float16 cannot represent every pixel coordinate above
        # 512 exactly (step 0.5 there) -- confirm this precision is intended.
        bboxes.append(torch.tensor(out['bboxes'][0][:4], dtype=torch.float16))
    return {'image': images, 'bbox': bboxes}
# Will run transforms on the fly when an item is called
train_ds.set_transform(transforms)

# View image with PIL.
# `transform` applies Normalize with its default ImageNet statistics, so the
# tensor holds standardised floats; casting those straight to uint8 does not
# give a viewable picture. Undo the normalisation before converting.
# NOTE(review): the inverse is exact only when the channel order seen by
# Normalize matches the channel order of the returned tensor -- confirm
# against `transforms`.
example = train_ds[800]
norm_mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
norm_std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
pixels = np.array(example["image"].permute(1, 2, 0)) * norm_std + norm_mean
Image.fromarray(np.uint8(np.clip(pixels * 255.0, 0.0, 255.0).round()))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment