@rexlow
Created April 22, 2019 05:03
Calculate data distribution (per-channel mean and standard deviation over a PyTorch DataLoader)
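import torch


def calculate_distribution(loader):
    """Return per-channel mean and population std of the images in `loader`.
    Assumes `loader` yields (data, label) with float `data` shaped (B, 3, H, W)."""
    step = 0  # pixels per channel seen so far
    # zeros, not torch.empty: empty leaves garbage (possibly NaN) that 0 * NaN would propagate
    mean = torch.zeros(3)     # running mean of x per channel
    mean_sq = torch.zeros(3)  # running mean of x ** 2 per channel
    for data, _ in loader:
        b, c, h, w = data.shape
        total_pixels = b * h * w
        sum_ = torch.sum(data, dim=[0, 2, 3])
        sum_of_square = torch.sum(data ** 2, dim=[0, 2, 3])
        # Weighted update: old average covers `step` pixels, this batch adds `total_pixels`
        mean = (step * mean + sum_) / (step + total_pixels)
        mean_sq = (step * mean_sq + sum_of_square) / (step + total_pixels)
        step += total_pixels
    # Var[x] = E[x^2] - E[x]^2; clamp guards against tiny negatives from rounding
    return mean, torch.sqrt((mean_sq - mean ** 2).clamp(min=0))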
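The running update in `calculate_distribution` relies on two facts: a pixel-count-weighted average of per-batch sums equals the mean over every pixel seen so far, and the population variance equals E[x^2] - E[x]^2. As a dependency-free check of that arithmetic (a sketch, not the gist author's code), the block below mirrors the same update on plain Python lists laid out as `batch[image][channel][row][col]`, an assumed stand-in for an NCHW tensor. `running_channel_stats` is a hypothetical name.

```python
import math
from typing import List, Sequence, Tuple

# Hypothetical nested-list layout: batch[image][channel][row][col], like an NCHW tensor.
Batch = Sequence[Sequence[Sequence[Sequence[float]]]]


def running_channel_stats(
    batches: Sequence[Batch], channels: int
) -> Tuple[List[float], List[float]]:
    """Per-channel population mean/std using the same running-average update."""
    step = 0                     # pixels per channel seen so far
    mean = [0.0] * channels      # running E[x] per channel
    mean_sq = [0.0] * channels   # running E[x^2] per channel
    for batch in batches:
        # Pixels per channel in this batch: b * h * w
        n = len(batch) * len(batch[0][0]) * len(batch[0][0][0])
        for c in range(channels):
            values = [v for image in batch for row in image[c] for v in row]
            s = sum(values)
            s2 = sum(v * v for v in values)
            mean[c] = (step * mean[c] + s) / (step + n)
            mean_sq[c] = (step * mean_sq[c] + s2) / (step + n)
        step += n
    # max(..., 0.0) guards against tiny negative variances caused by rounding
    std = [math.sqrt(max(m2 - m * m, 0.0)) for m, m2 in zip(mean, mean_sq)]
    return mean, std


# Two batches of different sizes; the result should match a single pass over all values.
mean, std = running_channel_stats([[[[[1.0, 3.0]]]], [[[[5.0]]]]], channels=1)
print(mean, std)
```

Comparing the output with `statistics.fmean` and `statistics.pstdev` on each channel's flattened values should agree to floating-point precision. Both this sketch and `calculate_distribution` return the population standard deviation (divide by N), not the sample (N - 1) version.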