Last active: May 30, 2019 22:58
-
-
Save Rm1n90/4b359a4aef94fa509f25421a8d6410d1 to your computer and use it in GitHub Desktop.
Computing the mean and std
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
computing the mean and the standard deviation per channel of any datasets with PyTorch | |
""" | |
import sys

import numpy as np
import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset
from torchvision import transforms
Path_to_Dataset = " " | |
def mean_std(data):
    """Compute the per-channel mean and standard deviation of a dataset.

    The function accumulates running sums of the pixel values and of their
    squares batch by batch, so the dataset never has to fit in memory at
    once. The statistics are derived only after the last batch. Taking the
    square root inside the loop would feed an already-rooted value back into
    the running second moment.

    Args:
        data: Iterable (e.g. a ``DataLoader``) yielding ``(images, labels)``
            pairs. ``images`` is a 4-D tensor laid out as
            ``(batch, channels, height, width)``. Labels are ignored.

    Returns:
        Tuple ``(mean, std)`` of 1-D tensors with one entry per channel.
        ``std`` is the population standard deviation (divides by N).

    Raises:
        ValueError: If ``data`` yields no pixels, so no statistic exists.
    """
    count = 0
    total = None  # running per-channel sum of pixel values
    total_sq = None  # running per-channel sum of squared pixel values
    out_dtype = None
    for images, _ in data:
        b, _, h, w = images.size()
        if out_dtype is None:
            # Return the default float dtype, or a wider one if the input
            # already is wider (e.g. float64 images).
            out_dtype = torch.promote_types(images.dtype, torch.get_default_dtype())
        # Accumulate in float64 to limit rounding error over many pixels and
        # to avoid integer overflow when squaring integer-typed images.
        images = images.to(torch.float64)
        batch_sum = images.sum(dim=[0, 2, 3])
        batch_sq = (images ** 2).sum(dim=[0, 2, 3])
        total = batch_sum if total is None else total + batch_sum
        total_sq = batch_sq if total_sq is None else total_sq + batch_sq
        count += b * h * w
    if count == 0:
        raise ValueError("mean_std: data yielded no pixels")
    mean = total / count
    # Var[X] = E[X^2] - E[X]^2. Clamp tiny negative values caused by rounding
    # so that sqrt never returns NaN.
    var = torch.clamp(total_sq / count - mean ** 2, min=0.0)
    return mean.to(out_dtype), torch.sqrt(var).to(out_dtype)
if __name__ == "__main__":
    # The guard stops the computation from running on import. It is also
    # required because DataLoader workers (num_workers > 0) may re-import
    # this module under spawn-based multiprocessing.
    # An optional first command-line argument overrides the placeholder path.
    dataset_path = sys.argv[1] if len(sys.argv) > 1 else Path_to_Dataset
    dataset = torchvision.datasets.ImageFolder(
        dataset_path,
        transform=transforms.Compose([transforms.ToTensor()]),
    )
    # With batch_size=1, images of different sizes can be processed. No
    # Resize is applied, and the default collate function can only stack
    # equal-sized tensors.
    data_loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=4)
    mean, std = mean_std(data_loader)
    print(mean, std)
    # output (RGB dataset) --> tensor([X.XXX, X.XXX, X.XXX]) tensor([X.XXX, X.XXX, X.XXX])
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment