Skip to content

Instantly share code, notes, and snippets.

@sharma0611
Last active June 23, 2020 14:44
Show Gist options
  • Save sharma0611/81e895698564bf804f05f001fe3807ef to your computer and use it in GitHub Desktop.
Save sharma0611/81e895698564bf804f05f001fe3807ef to your computer and use it in GitHub Desktop.
import matplotlib.pyplot as plt
import torch
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms

from modules.cifar10 import data_loader
# From modules/utils.py
class DeNormalize:
    """Invert a per-channel ``(x - mean) / std`` normalization, in place.

    The inverse applied is ``x * std + mean`` per channel. This is the
    counterpart of the commented normalize code ``t.sub_(m).div_(s)``.

    Args:
        mean (sequence of float): Per-channel means used when normalizing.
        std (sequence of float): Per-channel standard deviations used when
            normalizing.
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        """Denormalize ``tensor`` in place and return it.

        Args:
            tensor (Tensor): A single image of shape ``(C, H, W)`` or a batch
                of shape ``(..., C, H, W)``. ``C`` must equal the length of
                ``mean``/``std``.

        Returns:
            Tensor: The same tensor object, denormalized.
        """
        # Reshape the stats to (C, 1, 1) so they broadcast over H and W and over
        # any leading batch dims. Zipping over `tensor` would walk its first dim,
        # which for batched input is the batch axis, not the channel axis.
        mean = torch.as_tensor(self.mean, dtype=tensor.dtype, device=tensor.device)
        std = torch.as_tensor(self.std, dtype=tensor.dtype, device=tensor.device)
        tensor.mul_(std.view(-1, 1, 1)).add_(mean.view(-1, 1, 1))
        return tensor
# From modules/cifar10.py
def denormalize_transform():
    """Build the transform that undoes this module's input normalization.

    Returns:
        DeNormalize: Callable mapping normalized image tensors back to the
        original value range.
    """
    # NOTE(review): these stats must match the ones data_loader normalizes
    # with (not visible here) — confirm they stay in sync.
    channel_means = [0.4914, 0.4822, 0.4465]
    channel_stds = [0.2023, 0.1994, 0.2010]
    return DeNormalize(mean=channel_means, std=channel_stds)
# Sanity check (run in a notebook cell): display a few training images with
# the loader's normalization undone.
train_loader, val_loader = data_loader('./data', batch_size=3)
# Recent PyTorch DataLoader iterators no longer expose a `.next()` method;
# the builtin next() works across versions.
image, target = next(iter(train_loader))
# The data loader normalizes its output, so undo that before display.
# Apply the transform one image at a time: each slice is (C, H, W), a shape
# DeNormalize handles. Iterating a tensor yields views, so the batch itself
# is updated in place.
transform = denormalize_transform()
for single_image in image:
    transform(single_image)
grid_img = torchvision.utils.make_grid(image, nrow=3)
# Denormalized values can land slightly outside [0, 1]; clamp explicitly so
# imshow does not have to clip them (and warn) on its own.
plt.imshow(grid_img.permute(1, 2, 0).clamp(0, 1))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment