@calebrob6 · Last active May 28, 2024
A class for computing online mean and variance of multi-dimensional arrays in PyTorch (i.e. for computing per-channel stats over large image datasets).
import torch


class RunningStatsButFast(torch.nn.Module):
    def __init__(self, shape, dims):
        """Initializes the RunningStatsButFast module.

        A PyTorch module that can be put on the GPU and calculate the
        multidimensional mean and variance of inputs online in a numerically
        stable way. This is useful for calculating the channel-wise mean and
        variance of a large dataset because you don't have to load the entire
        dataset into memory.

        Uses the parallel algorithm from:
        https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
        A similar implementation can be found here:
        https://github.com/openai/baselines/blob/master/baselines/common/running_mean_std.py#L5

        Access the mean, variance, and standard deviation of the inputs with
        the `mean`, `var`, and `std` attributes.

        Example:
        ```
        rs = RunningStatsButFast((12,), [0, 2, 3])
        for inputs, _ in dataloader:
            rs(inputs)
        print(rs.mean)
        print(rs.var)
        print(rs.std)
        ```

        Args:
            shape: The shape of the resulting mean and variance. For example,
                if you are calculating statistics over the 0th, 2nd, and 3rd
                dimensions of inputs of size (64, 12, 256, 256), this should
                be 12.
            dims: The dimensions of the input to calculate the mean and
                variance over. In the above example, this should be [0, 2, 3].
        """
        super().__init__()
        self.register_buffer("mean", torch.zeros(shape))
        self.register_buffer("var", torch.ones(shape))
        self.register_buffer("std", torch.ones(shape))
        self.register_buffer("count", torch.zeros(1))
        self.dims = dims

    def update(self, x):
        with torch.no_grad():
            batch_mean = torch.mean(x, dim=self.dims)
            # Use the biased (population) variance so that batch_var * n is
            # the sum of squared deviations required by the merge formula.
            batch_var = torch.var(x, dim=self.dims, unbiased=False)
            # Using the batch dimension alone as the count is correct as long
            # as every batch has the same spatial dimensions. Create the count
            # on the same device as the input so GPU use works.
            batch_count = x.new_tensor(x.shape[self.dims[0]], dtype=torch.float)
            n_ab = self.count + batch_count
            m_a = self.mean * self.count
            m_b = batch_mean * batch_count
            M2_a = self.var * self.count
            M2_b = batch_var * batch_count
            delta = batch_mean - self.mean
            self.mean = (m_a + m_b) / n_ab
            # We don't subtract 1 from the denominator, so this is the biased
            # variance (NumPy's default; torch.var with unbiased=False).
            self.var = (M2_a + M2_b + delta ** 2 * self.count * batch_count / n_ab) / n_ab
            self.count += batch_count
            self.std = torch.sqrt(self.var + 1e-8)

    def forward(self, x):
        self.update(x)
        return x
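As a sanity check, the running statistics can be compared against statistics computed directly over the full concatenated dataset. The sketch below repeats the class compactly (docstring omitted) so it runs standalone; the batch shapes, seed, and tolerance are illustrative choices, not part of the gist:

```python
import torch


# RunningStatsButFast as above, docstring omitted for brevity.
class RunningStatsButFast(torch.nn.Module):
    def __init__(self, shape, dims):
        super().__init__()
        self.register_buffer("mean", torch.zeros(shape))
        self.register_buffer("var", torch.ones(shape))
        self.register_buffer("std", torch.ones(shape))
        self.register_buffer("count", torch.zeros(1))
        self.dims = dims

    def update(self, x):
        with torch.no_grad():
            batch_mean = torch.mean(x, dim=self.dims)
            batch_var = torch.var(x, dim=self.dims, unbiased=False)
            batch_count = x.new_tensor(x.shape[self.dims[0]], dtype=torch.float)
            n_ab = self.count + batch_count
            m_a = self.mean * self.count
            m_b = batch_mean * batch_count
            M2_a = self.var * self.count
            M2_b = batch_var * batch_count
            delta = batch_mean - self.mean
            self.mean = (m_a + m_b) / n_ab
            self.var = (M2_a + M2_b + delta ** 2 * self.count * batch_count / n_ab) / n_ab
            self.count += batch_count
            self.std = torch.sqrt(self.var + 1e-8)

    def forward(self, x):
        self.update(x)
        return x


torch.manual_seed(0)

# Fake dataset: 10 batches of (batch, channel, height, width) tensors with a
# nonzero mean and non-unit variance, so the comparison is non-trivial.
batches = [torch.randn(8, 3, 16, 16) * 2.0 + 5.0 for _ in range(10)]

rs = RunningStatsButFast((3,), [0, 2, 3])
for x in batches:
    rs(x)

# Direct per-channel statistics over the whole dataset at once.
full = torch.cat(batches, dim=0)
direct_mean = full.mean(dim=[0, 2, 3])
direct_var = full.var(dim=[0, 2, 3], unbiased=False)

print(torch.allclose(rs.mean, direct_mean, atol=1e-4))
print(torch.allclose(rs.var, direct_var, atol=1e-4))
```

The online result should agree with the direct computation up to floating-point error, while never holding more than one batch in memory at a time.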