Skip to content

Instantly share code, notes, and snippets.

@Rm1n90
Last active May 30, 2019 22:58
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Rm1n90/4b359a4aef94fa509f25421a8d6410d1 to your computer and use it in GitHub Desktop.
Save Rm1n90/4b359a4aef94fa509f25421a8d6410d1 to your computer and use it in GitHub Desktop.
Computing the mean and std
"""
Compute the per-channel mean and standard deviation of an image dataset with PyTorch.
"""
import numpy as np
import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset
from torchvision import transforms
# Dataset root directory. " " is only a placeholder — set a real path before
# running (presumably the root of the ImageFolder dataset loaded at the bottom
# of this script; confirm).
Path_to_Dataset: str = " "
def mean_std(data):
    """Compute the per-channel mean and standard deviation of an image dataset.

    Statistics are taken over every pixel of every image, so each image is
    weighted by its pixel count. Images of different sizes are therefore
    handled correctly (e.g. a loader with ``batch_size=1``).

    Args:
        data: An iterable of ``(images, label)`` pairs, such as a
            ``DataLoader`` over an ``ImageFolder``. ``images`` is a
            ``(batch, channels, height, width)`` tensor. An item that is a
            bare tensor (no label) is also accepted.

    Returns:
        A ``(mean, std)`` tuple of 1-D tensors with one entry per channel,
        in the default floating dtype. ``std`` is the population standard
        deviation (divides by N, not N - 1).

    Raises:
        ValueError: If ``data`` yields no pixels at all.
    """
    total = None     # running per-channel sum of pixel values
    total_sq = None  # running per-channel sum of squared pixel values
    count = 0        # number of pixels seen per channel
    for item in data:
        # Labelled loaders yield (images, label) — default_collate returns a
        # list — so keep only the image tensor.
        images = item[0] if isinstance(item, (tuple, list)) else item
        b, c, h, w = images.shape
        # Accumulate in float64: E[x^2] - E[x]^2 over millions of pixels
        # suffers badly from cancellation in float32.
        images = images.to(torch.float64)
        batch_sum = images.sum(dim=(0, 2, 3))
        batch_sq = (images ** 2).sum(dim=(0, 2, 3))
        if total is None:
            total, total_sq = batch_sum, batch_sq
        else:
            total = total + batch_sum
            total_sq = total_sq + batch_sq
        count += b * h * w
    if count == 0:
        raise ValueError("mean_std() needs at least one non-empty image batch")
    mean = total / count
    # Rounding can push the variance a hair below zero; clamp so sqrt never
    # yields NaN.
    var = (total_sq / count - mean ** 2).clamp(min=0)
    out_dtype = torch.get_default_dtype()
    return mean.to(out_dtype), var.sqrt().to(out_dtype)
if __name__ == "__main__":
    # The guard is required because num_workers > 0 starts worker processes.
    # Under the "spawn" start method (default on Windows and macOS) each worker
    # re-imports this module, and unguarded top-level code would fail.
    dataset = torchvision.datasets.ImageFolder(
        Path_to_Dataset,  # the configured root, not the literal "Path_to_Dataset"
        transform=transforms.Compose([transforms.ToTensor()]),
    )
    data_loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=4)
    mean, std = mean_std(data_loader)
    print(mean, std)
    # output --> tensor([X.XXX, X.XXX, X.XXX]) tensor([X.XXX, X.XXX, X.XXX])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment