Created
November 11, 2021 08:11
-
-
Save spirosdim/79fc88231fffec347f1ad5d14a36b5a8 to your computer and use it in GitHub Desktop.
Compute the per-channel mean and standard deviation of a folder of images, for use in image normalization, using PyTorch.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import PIL, os | |
from tqdm import tqdm | |
import numpy as np | |
import torch | |
from torch.utils.data import Dataset | |
class ImgDataset(Dataset):
    """Dataset that yields every matching image in a flat directory as a float32 tensor.

    Images are returned exactly as Pillow decodes them, converted to a
    ``float32`` tensor without rescaling (for multi-channel images NumPy
    lays this out as height x width x channels).

    Args:
        data_dir: Directory that directly contains the image files.
        extension: Filename suffix selecting the images. Anything accepted by
            ``str.endswith`` works, so a tuple of suffixes is allowed.
        transform: Optional callable applied to each image tensor before it
            is returned.
    """

    def __init__(self, data_dir, extension='.jpg', transform=None):
        self.data_dir = data_dir
        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> file mapping reproducible across runs and machines.
        self.image_files = sorted(
            f for f in os.listdir(data_dir) if f.endswith(extension)
        )
        self.transform = transform

    def __len__(self):
        """Return the number of image files found in ``data_dir``."""
        return len(self.image_files)

    def __getitem__(self, index):
        """Load the image at ``index`` and return it as a float32 tensor."""
        # A bare ``import PIL`` does not load the ``PIL.Image`` submodule, so
        # import it explicitly here instead of relying on ``PIL.Image``.
        from PIL import Image

        path = os.path.join(self.data_dir, self.image_files[index])
        # Image.open is lazy and keeps the file open; decode inside the
        # context manager so the handle is released immediately afterwards.
        with Image.open(path) as img:
            pixels = np.array(img, dtype=np.float32)
        image = torch.from_numpy(pixels)
        if self.transform is not None:
            image = self.transform(image)
        return image
def compute_mean_std(dataloader):
    """Compute the per-channel mean and (population) std over all images.

    Uses ``var[X] = E[X**2] - E[X]**2`` with running sums, so the data is
    traversed only once. Every pixel contributes equally: sums are weighted
    by the number of pixels in each batch, so a smaller final batch does not
    skew the result.

    Based on:
    https://github.com/aladdinpersson/Machine-Learning-Collection/blob/master/ML/Pytorch/Basics/pytorch_std_mean.py

    Args:
        dataloader: Iterable of image batches laid out as (B, H, W, C).
            Images within one batch must share the same height and width.

    Returns:
        Tuple ``(mean, std)`` of tensors reduced over the first three
        dimensions (one value per channel). The dtype follows the input
        batches when they are floating point, otherwise ``float32``.

    Raises:
        ValueError: If the dataloader yields no pixels at all.
    """
    channels_sum = 0.0
    channels_sqrd_sum = 0.0
    num_pixels = 0
    out_dtype = None

    for batch_images in tqdm(dataloader):  # (B, H, W, C)
        if out_dtype is None:
            out_dtype = (
                batch_images.dtype
                if torch.is_floating_point(batch_images)
                else torch.float32
            )
        # Accumulate in float64: summing many large float32 values loses
        # precision and makes E[X**2] - E[X]**2 unreliable.
        batch = batch_images.to(torch.float64)
        channels_sum = channels_sum + batch.sum(dim=[0, 1, 2])
        channels_sqrd_sum = channels_sqrd_sum + (batch ** 2).sum(dim=[0, 1, 2])
        num_pixels += batch.shape[0] * batch.shape[1] * batch.shape[2]

    if num_pixels == 0:
        raise ValueError('compute_mean_std received an empty dataloader')

    mean = channels_sum / num_pixels
    # Rounding can push the variance marginally below zero for (near-)constant
    # channels; clamp so the square root never produces NaN.
    variance = (channels_sqrd_sum / num_pixels - mean ** 2).clamp(min=0)
    std = variance.sqrt()
    return mean.to(out_dtype), std.to(out_dtype)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# The DataLoader below uses worker processes (num_workers=2). On platforms
# that start workers with "spawn" (Windows, macOS) each worker re-imports the
# main module, so the driver code must sit behind this guard or it would be
# executed again inside every worker.
if __name__ == '__main__':
    # create the DataLoader
    # data_dir is the root folder; the images themselves are read from its
    # 'train' subfolder. Modify ImgDataset for your own directory layout.
    data_dir = '/images_folder'
    dataset = ImgDataset(data_dir + '/train')
    train_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=256,
        shuffle=False,  # order is irrelevant for global statistics
        num_workers=2,
        pin_memory=True,
    )

    # output
    total_mean, total_std = compute_mean_std(train_loader)
    print('mean (RGB): ' + str(total_mean))
    print('std (RGB): ' + str(total_std))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment