@marcosfelt
Created January 17, 2023 12:31
Calculate mean and standard deviation using batch updates
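The loop below merges each new batch of k values into the running statistics of the N − k values seen so far, where N is the count after adding the batch. The mean is a count-weighted average. The variance uses the identity σ² = E[x²] − μ², so each group's second moment is σ² + μ². The combined second moment is the count-weighted average of the two groups' second moments, and subtracting the new squared mean gives the new (population, ddof=0) variance:

$$
\begin{aligned}
\mu_{\text{new}} &= \frac{N-k}{N}\,\mu_{\text{old}} + \frac{k}{N}\,\mu_{\text{batch}} \\
\mathbb{E}[x^2]_{\text{new}} &= \frac{N-k}{N}\left(\sigma_{\text{old}}^2 + \mu_{\text{old}}^2\right) + \frac{k}{N}\left(\sigma_{\text{batch}}^2 + \mu_{\text{batch}}^2\right) \\
\sigma_{\text{new}}^2 &= \mathbb{E}[x^2]_{\text{new}} - \mu_{\text{new}}^2
\end{aligned}
$$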
import numpy as np
x = np.arange(101)  # values 0..100; the 11th batch, x[100:110], holds only the value 100
N = 0
mean = 0
std = 0
for i in range(11):
    batch = x[10*i:10*(i+1)]
    k = len(batch)
    N += k
    old_mean = mean
    batch_mean = batch.mean()
    old_std = std
    batch_std = batch.std()
    # Mean update
    new_mean = (N - k) / N * old_mean + k / N * batch_mean
    # Variance update
    new_var = (
        (N - k) / N * (old_std**2 + old_mean**2)
        + k / N * (batch_std**2 + batch_mean**2)
        - new_mean**2
    )
    mean = new_mean
    std = np.sqrt(new_var)
print(f"Mean: {x.mean()} | Running mean: {mean}")
print(f"Std: {x.std():.03f} | Running std: {std:.03f}")
# Mean: 50.0 | Running mean: 50.0
# Std: 29.155 | Running std: 29.155
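The same idea can be packaged as a reusable accumulator. The sketch below is not the gist author's code: the `RunningStats` class and its `update`, `variance` and `std` names are hypothetical. It also swaps in a different technique, the pairwise combination of Chan, Golub and LeVeque, which tracks M2 (the sum of squared deviations from the mean) instead of E[x²] − μ². That avoids the precision loss the E[x²] − μ² form can suffer when the mean is large relative to the spread. Like `np.std` with its default `ddof=0`, it reports the population standard deviation.

```python
import numpy as np


class RunningStats:
    """Hypothetical batch accumulator for count, mean and population variance.

    Uses the pairwise (Chan et al.) update: each batch is merged via its own
    mean and M2 = sum((batch - batch_mean) ** 2).
    """

    def __init__(self):
        self.n = 0
        self.mean = 0.0
        self.m2 = 0.0  # sum of squared deviations from the running mean

    def update(self, batch):
        batch = np.asarray(batch, dtype=float)
        k = batch.size
        if k == 0:
            return  # skip empty batches instead of propagating NaN
        batch_mean = batch.mean()
        batch_m2 = ((batch - batch_mean) ** 2).sum()
        n_new = self.n + k
        delta = batch_mean - self.mean
        self.mean += delta * k / n_new
        # self.n is still the old count here, as the formula requires
        self.m2 += batch_m2 + delta**2 * self.n * k / n_new
        self.n = n_new

    @property
    def variance(self):
        # Population variance (ddof=0); NaN before any data has been seen
        return self.m2 / self.n if self.n else float("nan")

    @property
    def std(self):
        return float(np.sqrt(self.variance))


# Usage: the same data and batch size as the script above
x = np.arange(101)
stats = RunningStats()
for start in range(0, len(x), 10):
    stats.update(x[start:start + 10])
print(f"Mean: {x.mean()} | Running mean: {stats.mean}")
print(f"Std: {x.std():.03f} | Running std: {stats.std:.03f}")
```

Because `update` skips empty batches, looping past the end of the data is harmless here, whereas in the script above an empty batch would turn the running mean into NaN.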