Last active
July 28, 2024 09:44
-
-
Save NumairSayed/dfba5910139205e79f1b7022645418e6 to your computer and use it in GitHub Desktop.
Template for resizing images while preserving their aspect ratio (with padding) and for loading my image dataset.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""
The inference transforms are available at GoogLeNet_Weights.IMAGENET1K_V1.transforms
and perform the following preprocessing operations: they accept PIL.Image, batched (B, C, H, W),
and single (C, H, W) image torch.Tensor objects.
The images are resized to resize_size=[256] using interpolation=InterpolationMode.BILINEAR,
followed by a central crop of crop_size=[224]. Finally, the values are first
rescaled to [0.0, 1.0] and then normalized using mean=[0.485, 0.456, 0.406] and std=[0.229, 0.224, 0.225].
"""
import os | |
from PIL import Image | |
import torchvision.transforms as transforms | |
import torchvision.transforms.functional as F | |
from torch.utils.data import DataLoader, Dataset | |
class ResizeWithPad:
    """Fit an image inside a fixed-size black canvas, keeping its aspect ratio.

    The image is shrunk (never enlarged, because ``Image.thumbnail`` only
    reduces size) so that it fits within ``target_size``, then pasted centred
    onto a new black RGB canvas of exactly ``target_size``.

    Args:
        target_size: ``(width, height)`` of the output canvas in pixels.
    """

    def __init__(self, target_size):
        self.target_size = target_size

    def __call__(self, img):
        """Return a new ``target_size`` RGB image with ``img`` centred on it.

        Args:
            img: A ``PIL.Image.Image``. It is not modified.

        Returns:
            A new ``PIL.Image.Image`` in mode ``'RGB'`` of size ``target_size``.
        """
        width, height = self.target_size[0], self.target_size[1]

        # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
        # Prefer the Image.Resampling enum when the installed Pillow has it.
        resample = getattr(Image, "Resampling", Image).LANCZOS

        # thumbnail() resizes in place, so work on a copy to leave the
        # caller's image untouched.
        fitted = img.copy()
        fitted.thumbnail((width, height), resample)

        canvas = Image.new('RGB', (width, height), (0, 0, 0))
        # Split the leftover space evenly so the image sits in the centre.
        offset = ((width - fitted.size[0]) // 2, (height - fitted.size[1]) // 2)
        canvas.paste(fitted, offset)
        return canvas
# Preprocessing pipeline: pad each image to a fixed size, then convert the
# resulting PIL image to a tensor.
transform = transforms.Compose(
    [
        ResizeWithPad((224, 224)),  # Replace with your target dimensions
        transforms.ToTensor(),
    ]
)
class CustomImageDataset(Dataset):
    """Unlabelled image dataset built from the image files in one directory.

    Only direct children of ``root_dir`` are considered (no recursion), and a
    file is selected when its lower-cased name ends with one of
    ``IMAGE_EXTENSIONS``.

    Args:
        root_dir: Directory containing the image files.
        transform: Optional callable applied to each loaded RGB image.
    """

    # File-name suffixes (compared case-insensitively) treated as images.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.gif', '.tiff')

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        # os.listdir returns entries in arbitrary order; sort so that an index
        # always maps to the same file across runs and machines.
        self.image_files = sorted(
            os.path.join(root_dir, name)
            for name in os.listdir(root_dir)
            if name.lower().endswith(self.IMAGE_EXTENSIONS)
        )

    def __len__(self):
        """Return the number of image files found in ``root_dir``."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load the image at position ``idx`` as RGB and apply the transform.

        Returns:
            The transformed image, or the RGB ``PIL.Image.Image`` itself when
            no transform was given.
        """
        img_path = self.image_files[idx]
        # Close the underlying file as soon as the pixels are converted, so
        # iterating over a large dataset does not leak file handles.
        with Image.open(img_path) as raw:
            image = raw.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        return image
# Example usage: build the dataset from a folder of images and iterate over
# it in shuffled batches of 32.
dataset = CustomImageDataset(
    root_dir='path_to_your_root_folder',
    transform=transform,
)
dataloader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
)

for batch in dataloader:
    # Your training code here
    pass
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment