Skip to content

Instantly share code, notes, and snippets.

@jdhao
Last active September 20, 2023 06:36
Show Gist options
  • Save jdhao/9a86d4b9e4f79c5330d54de991461fd6 to your computer and use it in GitHub Desktop.
Save jdhao/9a86d4b9e4f79c5330d54de991461fd6 to your computer and use it in GitHub Desktop.
This snippet calculates the per-channel image mean and standard deviation over the training image set. It is deliberately simple and may not be efficient for large-scale datasets.
"""
This script calculates the per-channel image mean and standard
deviation on the training set only. Do not compute these statistics on
the whole dataset; see http://cs231n.github.io/neural-networks-2/#datapre
"""
import numpy as np
from os import listdir
from os.path import join, isdir
from glob import glob
import cv2
import timeit
# Number of channels in the dataset images: 3 for color JPEGs, 1 for grayscale.
# Change this value to match your dataset.
CHANNEL_NUM = 3
def cal_dir_stat(root):
    """Compute the per-channel mean and standard deviation of an image set.

    ``root`` is expected to contain one sub-directory per class. Every
    ``*.jpg`` file directly inside a class directory is read with OpenCV,
    divided by 255 and folded into running per-channel sums, so only one
    image is held in memory at a time.

    Args:
        root: path of the directory that holds the class sub-directories.

    Returns:
        A ``(mean, std)`` tuple of lists with one value per channel. OpenCV
        yields channels in BGR order; the lists are reversed so that color
        statistics come back in RGB order.

    Raises:
        ValueError: if no readable image is found under ``root``.
    """
    # cv2.imread decodes to 3-channel BGR by default, even for grayscale
    # files, so a single channel must be requested explicitly; otherwise the
    # per-image sums would not match the CHANNEL_NUM-sized accumulators.
    read_flag = cv2.IMREAD_GRAYSCALE if CHANNEL_NUM == 1 else cv2.IMREAD_COLOR

    cls_dirs = [d for d in listdir(root) if isdir(join(root, d))]
    pixel_num = 0  # total number of pixels accumulated over all images
    channel_sum = np.zeros(CHANNEL_NUM)
    channel_sum_squared = np.zeros(CHANNEL_NUM)

    for idx, d in enumerate(cls_dirs):
        print("#{} class".format(idx))
        for path in glob(join(root, d, "*.jpg")):
            im = cv2.imread(path, read_flag)
            if im is None:
                # imread signals an unreadable/corrupt file by returning None
                # instead of raising; skip it rather than abort the whole run.
                print("skipping unreadable image: {}".format(path))
                continue
            if im.ndim == 2:
                # Grayscale images come back as M*N; add the channel axis so
                # the reduction over axes (0, 1) below still applies.
                im = im[:, :, np.newaxis]

            im = im / 255.0
            pixel_num += im.shape[0] * im.shape[1]
            channel_sum += np.sum(im, axis=(0, 1))
            channel_sum_squared += np.sum(np.square(im), axis=(0, 1))

    if pixel_num == 0:
        # Without this guard the divisions below would yield NaN statistics.
        raise ValueError("no readable *.jpg images found under {!r}".format(root))

    bgr_mean = channel_sum / pixel_num
    # E[x^2] - E[x]^2 can come out marginally negative through floating-point
    # rounding when a channel is (almost) constant; clamp before the sqrt.
    variance = channel_sum_squared / pixel_num - np.square(bgr_mean)
    bgr_std = np.sqrt(np.maximum(variance, 0.0))

    # change the format from bgr to rgb
    rgb_mean = list(bgr_mean)[::-1]
    rgb_std = list(bgr_std)[::-1]
    return rgb_mean, rgb_std
# Expected layout: train_root/<class_name>/*.jpg, i.e. one sub-directory
# of training images per class.
train_root = "/hd1/jdhao/firearm-dataset/train/"

# Time the full pass over the training set.
start = timeit.default_timer()
mean, std = cal_dir_stat(train_root)
end = timeit.default_timer()

# Report the wall-clock time first, then the statistics.
print(
    "elapsed time: {}".format(end - start),
    "mean:{}\nstd:{}".format(mean, std),
    sep="\n",
)
@morganmcg1
Copy link

thanks for this :)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment