Skip to content

Instantly share code, notes, and snippets.

@revantteotia
Created May 9, 2021 15:10
Show Gist options
  • Save revantteotia/d68d111b98ef7baa3eabb4930c028efa to your computer and use it in GitHub Desktop.
Save revantteotia/d68d111b98ef7baa3eabb4930c028efa to your computer and use it in GitHub Desktop.
UnNormalize an image tensor
class UnNormalize(object):
    """Undo a per-channel ``(x - mean) / std`` normalization of an image tensor.

    Each channel ``t`` is mapped back with ``t * std + mean``.

    Args:
        mean: Sequence of per-channel means used by the normalization.
        std: Sequence of per-channel standard deviations used by the
            normalization.
        inplace: If True (the default, matching the original behavior), the
            tensor passed to ``__call__`` is modified and returned. If False,
            a clone is modified and the caller's tensor is left untouched.
    """

    def __init__(self, mean, std, inplace=True):
        self.mean = mean
        self.std = std
        self.inplace = inplace

    def __call__(self, tensor):
        """
        Args:
            tensor (Tensor): Normalized tensor image of size (C, H, W).
        Returns:
            Tensor: Unnormalized image. This is the input tensor itself when
            ``inplace`` is True, otherwise a new tensor.
        """
        if not self.inplace:
            # Work on a copy so the caller's normalized tensor survives.
            tensor = tensor.clone()
        # zip stops at the shortest of (channels, mean, std); any extra
        # channels are returned unchanged.
        for t, m, s in zip(tensor, self.mean, self.std):
            # Inverse of the normalize step ``t.sub_(m).div_(s)``. The
            # in-place ops on each channel write through to ``tensor``.
            t.mul_(s).add_(m)
        return tensor
# example
# unorm = UnNormalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
# unnormalized_img_tensor = unorm(normalized_img_tensor)
# unnormalized_img = torchvision.transforms.ToPILImage()(unnormalized_img_tensor).convert("RGB")
# import matplotlib.pyplot as plt
# plt.imshow(unnormalized_img)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment