Skip to content

Instantly share code, notes, and snippets.

@spirosdim
Created November 11, 2021 08:11
Show Gist options
  • Save spirosdim/79fc88231fffec347f1ad5d14a36b5a8 to your computer and use it in GitHub Desktop.
Save spirosdim/79fc88231fffec347f1ad5d14a36b5a8 to your computer and use it in GitHub Desktop.
Compute the per-channel mean and standard deviation of a folder of images, for use in image normalization, using PyTorch.
import PIL, os
from tqdm import tqdm
import numpy as np
import torch
from torch.utils.data import Dataset
class ImgDataset(Dataset):
    '''
    Dataset that serves every matching image file of one directory as a tensor.

    Each item is the decoded image converted through
    ``np.array(..., dtype=np.float32)`` and ``torch.from_numpy``; no rescaling
    or channel reordering is applied here.

    Args:
        data_dir: directory that directly contains the image files.
        extension: filename suffix used to select files (case-sensitive).
        transform: optional callable applied to the image tensor before it is
            returned.
    '''

    def __init__(self, data_dir, extension='.jpg', transform=None):
        self.data_dir = data_dir
        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> file mapping reproducible across runs and machines.
        self.image_files = sorted(
            f for f in os.listdir(data_dir) if f.endswith(extension)
        )
        self.transform = transform

    def __len__(self):
        '''Return the number of image files selected in ``data_dir``.'''
        return len(self.image_files)

    def __getitem__(self, index):
        '''Load the image at ``index`` and return it as a float32 tensor.'''
        # A bare ``import PIL`` does not load the ``PIL.Image`` submodule, so
        # import it explicitly instead of relying on another package having
        # imported it first.
        from PIL import Image

        path = os.path.join(self.data_dir, self.image_files[index])
        # Close the file as soon as the pixels are copied out; otherwise the
        # handle stays open until the image object is garbage-collected.
        with Image.open(path) as img:
            pixels = np.array(img, dtype=np.float32)
        image = torch.from_numpy(pixels)
        if self.transform:
            image = self.transform(image)
        return image
def compute_mean_std(dataloader):
    '''
    Compute the per-channel mean and standard deviation of all images
    yielded by a dataloader.

    Batches are expected channels-last, (B,H,W,C); statistics are reduced over
    the first three dimensions. Images inside one batch must share height and
    width. Sums are weighted by the number of pixels, so a smaller final batch
    (or batches of a different image size) does not bias the result.

    source: https://github.com/aladdinpersson/Machine-Learning-Collection/blob/master/ML/Pytorch/Basics/pytorch_std_mean.py

    Args:
        dataloader: iterable yielding image batch tensors.

    Returns:
        Tuple ``(mean, std)`` of tensors, in the dtype of the first batch
        (float32 if the batches are not floating point).

    Raises:
        ValueError: if the dataloader yields no pixels at all.
    '''
    # var[X] = E[X**2] - E[X]**2
    channels_sum, channels_sqrd_sum, num_pixels = 0, 0, 0
    out_dtype = None
    for batch_images in tqdm(dataloader):  # (B,H,W,C)
        if out_dtype is None:
            out_dtype = (batch_images.dtype if batch_images.is_floating_point()
                         else torch.float32)
        # Accumulate in float64: sums of squared pixel values over a whole
        # dataset lose too much precision in float32.
        batch = batch_images.to(torch.float64)
        channels_sum += torch.sum(batch, dim=[0, 1, 2])
        channels_sqrd_sum += torch.sum(batch ** 2, dim=[0, 1, 2])
        # Count pixels per channel instead of batches, so that every pixel
        # carries the same weight regardless of batch size.
        num_pixels += batch.shape[0] * batch.shape[1] * batch.shape[2]
    if num_pixels == 0:
        raise ValueError('cannot compute mean/std: the dataloader yielded no images')
    mean = channels_sum / num_pixels
    # Rounding can push the variance marginally below zero, which would turn
    # the square root into NaN; clamp before taking it.
    variance = torch.clamp(channels_sqrd_sum / num_pixels - mean ** 2, min=0)
    std = variance ** 0.5
    return mean.to(out_dtype), std.to(out_dtype)
# The DataLoader below uses worker processes (num_workers=2). On platforms
# whose multiprocessing start method is "spawn" (Windows, macOS) each worker
# re-imports this module, so the script body must be guarded to avoid being
# re-executed inside every worker.
if __name__ == '__main__':
    # create the DataLoader
    data_dir = '/images_folder'  # we assume that all images are in this folder, modify ImgDataset for your own case.
    dataset = ImgDataset(data_dir+'/train')
    train_loader = torch.utils.data.DataLoader(dataset, batch_size=256, shuffle=False,
                                               num_workers=2, pin_memory=True)
    # output
    total_mean, total_std = compute_mean_std(train_loader)
    print('mean (RGB): ' + str(total_mean))
    print('std (RGB): ' + str(total_std))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment