Skip to content

Instantly share code, notes, and snippets.

@NumairSayed
Last active July 28, 2024 09:44
Show Gist options
  • Save NumairSayed/dfba5910139205e79f1b7022645418e6 to your computer and use it in GitHub Desktop.
Save NumairSayed/dfba5910139205e79f1b7022645418e6 to your computer and use it in GitHub Desktop.
Template for resizing images while preserving their aspect ratio (padding to a fixed size) and for loading an image dataset from a folder.
"""
The inference transforms are available at GoogLeNet_Weights.IMAGENET1K_V1.transforms
and perform the following preprocessing operations: Accepts PIL.Image, batched (B, C, H, W)
and single (C, H, W) image torch.Tensor objects.
The images are resized to resize_size=[256] using
interpolation=InterpolationMode.BILINEAR, followed by a central crop of crop_size=[224].
Finally, the values are rescaled to [0.0, 1.0] and then normalized using
mean=[0.485, 0.456, 0.406] and std=[0.229, 0.224, 0.225].
"""
import os
from PIL import Image
import torchvision.transforms as transforms
import torchvision.transforms.functional as F
from torch.utils.data import DataLoader, Dataset
class ResizeWithPad:
    """Fit an image inside a fixed-size canvas without distorting it.

    The image is scaled down, keeping its aspect ratio, until it fits inside
    ``target_size``; it is then centered on a black RGB canvas of exactly
    ``target_size``. Images already smaller than the target are not enlarged
    (``Image.thumbnail`` only ever shrinks), they are just centered and padded.

    Args:
        target_size: ``(width, height)`` of the output in pixels, or a single
            int for a square output.
    """

    def __init__(self, target_size):
        # Accept a bare int as shorthand for a square canvas.
        if isinstance(target_size, int):
            target_size = (target_size, target_size)
        self.target_size = target_size

    def __call__(self, img):
        """Return a new padded RGB image; ``img`` itself is left unmodified.

        Args:
            img: A ``PIL.Image.Image``.

        Returns:
            A new RGB ``PIL.Image.Image`` whose size equals ``target_size``.
        """
        target_w, target_h = self.target_size

        # ``thumbnail`` resizes in place, so work on a copy to avoid
        # silently shrinking the caller's image.
        resized = img.copy()
        # ``Image.ANTIALIAS`` was removed in Pillow 10; ``Image.LANCZOS`` is
        # the same filter and exists in old and new Pillow releases alike.
        resized.thumbnail((target_w, target_h), Image.LANCZOS)

        canvas = Image.new('RGB', (target_w, target_h), (0, 0, 0))
        # Split the leftover space evenly so the image sits in the center.
        offset = (
            (target_w - resized.size[0]) // 2,
            (target_h - resized.size[1]) // 2,
        )
        canvas.paste(resized, offset)
        return canvas
# Preprocessing pipeline applied to every image: pad-resize to a fixed
# size first, then convert the PIL image to a tensor.
transform = transforms.Compose(
    [
        ResizeWithPad((224, 224)),  # Replace with your target dimensions
        transforms.ToTensor(),
    ]
)
class CustomImageDataset(Dataset):
    """Unlabeled dataset of the image files found directly inside a folder.

    Files are selected by extension (case-insensitive); sub-folders are not
    searched. Each item is the image alone, with no label attached.

    Args:
        root_dir: Folder containing the image files.
        transform: Optional callable applied to each loaded RGB PIL image.
    """

    # File extensions (lower-case) that are treated as images.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.gif', '.tif', '.tiff')

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        # ``os.listdir`` returns entries in arbitrary order; sorting makes the
        # index -> file mapping reproducible across runs and machines.
        candidates = (
            os.path.join(root_dir, name)
            for name in sorted(os.listdir(root_dir))
            if name.lower().endswith(self.IMAGE_EXTENSIONS)
        )
        # Skip directories that merely happen to be named like an image file.
        self.image_files = [path for path in candidates if os.path.isfile(path)]

    def __len__(self):
        """Return the number of image files found in ``root_dir``."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB and return it, transformed if requested."""
        img_path = self.image_files[idx]
        # ``convert`` yields an independent in-memory image, so the underlying
        # file can be closed right away instead of lingering until GC.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        return image
# Build the dataset from the image folder and wrap it in a shuffling
# loader that yields batches of 32 images.
dataset = CustomImageDataset(
    root_dir='path_to_your_root_folder',
    transform=transform,
)
dataloader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
)

for batch in dataloader:
    # Your training code here
    pass
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment