Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save mihir135/6d3dfeb7887cf1c987e84f8f9365b37c to your computer and use it in GitHub Desktop.
Save mihir135/6d3dfeb7887cf1c987e84f8f9365b37c to your computer and use it in GitHub Desktop.
This snippet calculates the per-channel image mean and standard deviation over the training image set. It is plain and simple, and may not be efficient for large-scale datasets.
"""
This script calculates the per-channel image mean and standard deviation
over the training set only. Do not compute these statistics on the whole
dataset; see http://cs231n.github.io/neural-networks-2/#datapre
"""
import numpy as np
from os import listdir
from os.path import join, isdir
from glob import glob
import cv2
import timeit
# Number of channels per dataset image: 3 for color jpg, 1 for grayscale.
# Change this to reflect your dataset.
CHANNEL_NUM = 3


def _load_pixels(path):
    """Read one image and return its pixels as a 2-D float array in [0, 1].

    The result has one row per pixel and one column per channel (BGR order
    for color images). Returns None when OpenCV cannot read or decode the
    file.
    """
    # cv2.imread decodes to 3 channels by default, so a single-channel
    # dataset has to request grayscale decoding explicitly.
    if CHANNEL_NUM == 1:
        im = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    else:
        im = cv2.imread(path)
    if im is None:
        return None
    # Flatten the spatial dimensions; works for both (H, W) and (H, W, C).
    return im.reshape(im.shape[0] * im.shape[1], -1) / 255.0


def cal_dir_stat(root):
    """Compute the per-channel mean and standard deviation of a dataset.

    Every immediate sub-directory of ``root`` is treated as one class and
    all ``*.jpg`` files directly inside it are accumulated. Pixel values are
    scaled to [0, 1] before accumulation. Files that OpenCV cannot read are
    reported and skipped.

    Args:
        root: directory containing one sub-directory per class.

    Returns:
        A ``(mean, std)`` tuple of lists with ``CHANNEL_NUM`` entries each,
        reversed from OpenCV's BGR order into RGB order.

    Raises:
        ValueError: if no readable image is found under ``root``.
    """
    cls_dirs = [d for d in listdir(root) if isdir(join(root, d))]
    pixel_num = 0  # total number of pixels seen across the dataset
    channel_sum = np.zeros(CHANNEL_NUM)
    channel_sum_squared = np.zeros(CHANNEL_NUM)

    for idx, d in enumerate(cls_dirs):
        print("#{} class".format(idx))
        for path in glob(join(root, d, "*.jpg")):
            pixels = _load_pixels(path)
            if pixels is None:
                # A corrupt or unreadable file should not abort the whole run.
                print("skipping unreadable image: {}".format(path))
                continue
            pixel_num += pixels.shape[0]
            channel_sum += pixels.sum(axis=0)
            channel_sum_squared += np.square(pixels).sum(axis=0)

    if pixel_num == 0:
        # Without this guard the divisions below silently produce NaN.
        raise ValueError(
            "no readable *.jpg images found under {!r}".format(root)
        )

    bgr_mean = channel_sum / pixel_num
    # E[x^2] - E[x]^2 can come out marginally negative through rounding;
    # clamp at zero so the square root never yields NaN.
    variance = channel_sum_squared / pixel_num - np.square(bgr_mean)
    bgr_std = np.sqrt(np.maximum(variance, 0.0))

    # change the format from bgr to rgb
    rgb_mean = list(bgr_mean)[::-1]
    rgb_std = list(bgr_std)[::-1]
    return rgb_mean, rgb_std
# Expected layout: train_root holds one sub-directory per class, and each
# sub-directory holds that class's training images.
train_root = "/hd1/jdhao/firearm-dataset/train/"

# Time the full pass over the training set.
start = timeit.default_timer()
mean, std = cal_dir_stat(train_root)
end = timeit.default_timer()

elapsed = end - start
print("elapsed time: {}".format(elapsed))
print("mean:{}\nstd:{}".format(mean, std))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment