Skip to content

Instantly share code, notes, and snippets.

@CMCDragonkai
Last active September 7, 2019 07:33
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save CMCDragonkai/457d587cdeeeb005a80e5125813b79d6 to your computer and use it in GitHub Desktop.
Save CMCDragonkai/457d587cdeeeb005a80e5125813b79d6 to your computer and use it in GitHub Desktop.
MNIST Data Preparation - Normalisation to Z-Score for CNNs #python
import itertools
import os
from pathlib import Path

import keras.utils as utils
import numpy as np
import skimage.transform
from mnist import MNIST
def chunk_gen(size, iterable):
    """Lazily split ``iterable`` into consecutive tuples of ``size`` items.

    Every tuple holds exactly ``size`` items except possibly the last, which
    holds whatever remains. Iteration stops once the source is exhausted, so
    an empty iterable (or ``size == 0``) yields nothing. Infinite iterables
    are consumed one chunk at a time.

    >>> list(chunk_gen(2, 'abcde'))
    [('a', 'b'), ('c', 'd'), ('e',)]
    """
    source = iter(iterable)

    def next_chunk():
        # Pull up to `size` items from the shared iterator.
        return tuple(itertools.islice(source, size))

    # Two-argument iter() keeps calling next_chunk until it returns the
    # sentinel () -- i.e. until the source has no items left.
    yield from iter(next_chunk, ())
def data_gen(xs, ys, xs_shape, xs_mean, xs_stddev, ys_class_count, batch_size):
    """Yield an endless stream of resized, z-scored, one-hot encoded batches.

    The (x, y) pairs are cycled indefinitely and grouped into batches of
    ``batch_size``. Each image is resized to ``xs_shape`` with
    ``skimage.transform.resize`` and standardised as
    ``(x - xs_mean) / xs_stddev``. Each label is one-hot encoded into
    ``ys_class_count`` classes with ``keras.utils.to_categorical``.

    Args:
        xs: Sequence of images; each element is passed to
            ``skimage.transform.resize``.
        ys: Sequence of class labels, paired with ``xs`` by ``zip``.
        xs_shape: Target output shape handed to ``skimage.transform.resize``.
        xs_mean: Value subtracted from every resized image (a scalar or an
            array that broadcasts against the image, e.g. per-channel).
        xs_stddev: Value each centred image is divided by (same broadcasting
            rules as ``xs_mean``).
        ys_class_count: Number of classes for one-hot encoding.
        batch_size: Number of samples per yielded batch; must be >= 1.

    Yields:
        ``(zs_batch, ys_batch)``: ``zs_batch`` stacks the standardised images
        along a new leading axis; ``ys_batch`` stacks the one-hot labels.

    Raises:
        ValueError: If ``batch_size`` is less than 1 (raised on the first
            ``next()``, since this is a generator).

    Note:
        Because the pairs are cycled, the last batch of a pass may wrap
        around and include samples from the start of the data. If ``xs`` /
        ``ys`` are empty, the generator finishes without yielding anything.
    """
    # Without this guard, 0 makes chunk_gen stop immediately (an empty
    # generator), and None makes islice try to drain the infinite cycle.
    if batch_size is None or batch_size < 1:
        raise ValueError(
            'batch_size must be a positive integer, got {!r}'.format(batch_size))

    def resize(x):
        # Bilinear (order=1) resize. preserve_range keeps the original pixel
        # scale, so the z-score is in the same units as xs_mean / xs_stddev.
        return skimage.transform.resize(
            x,
            xs_shape,
            order=1,
            mode='constant',
            preserve_range=True,
            anti_aliasing=True)

    # The last batch may overlap with the first samples due to cycling.
    for batch in chunk_gen(batch_size, itertools.cycle(zip(xs, ys))):
        # batch is ((x, y), (x, y), ...); transpose it into xs and ys tuples.
        xs_batch, ys_batch = zip(*batch)
        resized_batch = np.array([resize(x) for x in xs_batch])
        # Standardise the whole batch at once; xs_mean / xs_stddev broadcast
        # over the leading batch axis exactly as they did per image.
        zs_batch = (resized_batch - xs_mean) / xs_stddev
        onehot_batch = [utils.to_categorical(y, ys_class_count) for y in ys_batch]
        yield (zs_batch, np.array(onehot_batch))
# Raw MNIST image geometry and the generator configuration.
IMAGE_ROWS, IMAGE_COLS, IMAGE_CHANNELS = 28, 28, 1
TARGET_SHAPE = (224, 224)
CLASS_COUNT = 10
TRAIN_BATCH_SIZE = 2
VAL_BATCH_SIZE = 1

# The dataset location is supplied through the MNIST_DATASET environment
# variable.
mnist_dir = os.environ.get('MNIST_DATASET')
if mnist_dir is None:
    # Without this check, Path(None) fails with an opaque TypeError.
    raise RuntimeError(
        'The MNIST_DATASET environment variable must be set to the '
        'directory containing the MNIST dataset files')
mnist_dir = Path(mnist_dir)

data = MNIST(mnist_dir, return_type='numpy')
# Ask python-mnist to load the gzip-compressed dataset files.
data.gz = True
train_xs, train_ys = data.load_training()
val_xs, val_ys = data.load_testing()

# Shape is (NUM, ROWS, COLS, CHANNELS), i.e. channels_last format.
train_xs = np.reshape(
    train_xs, (train_xs.shape[0], IMAGE_ROWS, IMAGE_COLS, IMAGE_CHANNELS))
val_xs = np.reshape(
    val_xs, (val_xs.shape[0], IMAGE_ROWS, IMAGE_COLS, IMAGE_CHANNELS))

# The mean and stddev are calculated from the training set only and reused
# for the validation generator. They could instead be taken from the
# validation set or the whole dataset, but then the validation data would
# influence the preprocessing applied during training.
# Reducing over every axis except the last (channel) axis gives one mean and
# one stddev per channel; MNIST has a single channel, so each is a length-1
# array.
# NOTE(review): these statistics come from the 28x28 originals but are
# applied to the resized images inside data_gen. This assumes resizing does
# not materially shift the pixel statistics -- confirm this is acceptable.
reduce_axes = tuple(range(train_xs.ndim - 1))
mean = train_xs.mean(axis=reduce_axes)
stddev = train_xs.std(axis=reduce_axes)

train_gen = data_gen(
    train_xs, train_ys, TARGET_SHAPE, mean, stddev, CLASS_COUNT,
    TRAIN_BATCH_SIZE)
val_gen = data_gen(
    val_xs, val_ys, TARGET_SHAPE, mean, stddev, CLASS_COUNT,
    VAL_BATCH_SIZE)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment