Last active
September 7, 2019 07:33
-
-
Save CMCDragonkai/457d587cdeeeb005a80e5125813b79d6 to your computer and use it in GitHub Desktop.
MNIST Data Preparation - Normalisation to Z-Score for CNNs #python
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import itertools | |
import numpy as np | |
import keras.utils as utils | |
from pathlib import Path | |
from mnist import MNIST | |
def chunk_gen(size, iterable):
    """Split ``iterable`` lazily into consecutive tuples of at most ``size`` items.

    Every chunk holds exactly ``size`` items except possibly the last, which
    holds whatever remains. Iteration stops as soon as the source is
    exhausted, so this also works on infinite iterables (chunks are produced
    on demand). A ``size`` of 0 yields nothing.

    :param size: maximum number of items per chunk.
    :param iterable: any iterable to split.
    :return: generator of tuples.
    """
    source = iter(iterable)

    def next_chunk():
        return tuple(itertools.islice(source, size))

    # iter(callable, sentinel) keeps calling next_chunk until it returns the
    # empty tuple, i.e. until the source has been drained.
    yield from iter(next_chunk, ())
def data_gen(xs, ys, xs_shape, xs_mean, xs_stddev, ys_class_count, batch_size):
    """Yield an endless stream of resized, z-score normalised, one-hot batches.

    Samples are taken from ``zip(xs, ys)`` in order and cycled forever, so the
    generator never terminates while the input is non-empty (it yields nothing
    when the input is empty or ``batch_size`` is 0).

    :param xs: sequence of input images; each is passed to
        ``skimage.transform.resize``.
    :param ys: sequence of integer class labels, parallel to ``xs``.
    :param xs_shape: target output shape for each image, e.g. ``(224, 224)``.
    :param xs_mean: value(s) subtracted from every resized image; broadcast
        against the stacked batch (the module passes per-channel means).
    :param xs_stddev: value(s) each centred image is divided by; broadcast
        the same way as ``xs_mean``.
    :param ys_class_count: number of classes for the one-hot encoding.
    :param batch_size: number of samples per yielded batch.
    :return: generator of ``(zs_batch, ys_batch)`` numpy array pairs, where
        ``zs_batch`` stacks the normalised images along a new leading axis and
        ``ys_batch`` has shape ``(batch_size, ys_class_count)``.
    """
    # skimage is not imported at module level; without this import the
    # generator raised NameError on its first next().
    import skimage.transform

    def resize(x):
        return skimage.transform.resize(
            x,
            xs_shape,
            order=1,
            mode='constant',
            preserve_range=True,
            anti_aliasing=True)

    # Batches may straddle the end of one pass and the start of the next,
    # because the (x, y) pairs are cycled endlessly.
    for batch in chunk_gen(batch_size, itertools.cycle(zip(xs, ys))):
        # batch is ((x, y), (x, y), ...)
        xs_batch, ys_batch = zip(*batch)
        # Resize each sample, then z-score the whole stacked batch in one
        # vectorised step; xs_mean / xs_stddev broadcast over the batch axis.
        resized = np.array([resize(x) for x in xs_batch])
        zs_batch = (resized - xs_mean) / xs_stddev
        # Encoding the whole label vector at once always gives a 2-D
        # (batch, ys_class_count) array, rather than depending on how a given
        # keras version shapes the one-hot encoding of a single scalar label.
        ys_onehot = utils.to_categorical(np.asarray(ys_batch), ys_class_count)
        yield zs_batch, ys_onehot
# Image size fed to the network, number of MNIST classes, and batch sizes.
IMAGE_SHAPE = (224, 224)
NUM_CLASSES = 10
TRAIN_BATCH_SIZE = 2
VAL_BATCH_SIZE = 1

# Directory containing the MNIST dataset files, taken from the environment.
mnist_dir = os.environ.get('MNIST_DATASET')
if not mnist_dir:
    # Without this check, an unset variable reaches Path(None) and fails with
    # an unhelpful TypeError.
    raise RuntimeError(
        'MNIST_DATASET environment variable must be set to the MNIST '
        'dataset directory')
mnist_dir = Path(mnist_dir)

data = MNIST(mnist_dir, return_type='numpy')
# Presumably tells python-mnist to read the gzipped (.gz) files; confirm
# against the installed python-mnist version.
data.gz = True
train_xs, train_ys = data.load_training()
val_xs, val_ys = data.load_testing()

# Shape is (NUM, ROW, COL, CHANNELS), i.e. channels_last format.
train_xs = np.reshape(train_xs, (train_xs.shape[0], 28, 28, 1))
val_xs = np.reshape(val_xs, (val_xs.shape[0], 28, 28, 1))

# The mean and stddev are computed from the training set and reused for the
# validation generator. They could also be computed from the validation set
# or from both sets; which is appropriate depends on the statistics of the
# problem.
# Reducing over every axis except the last gives per-channel statistics.
# MNIST has a single channel, so each result is a numpy array of shape (1,).
# NOTE(review): these statistics come from the 28x28 originals, but data_gen
# normalises bilinearly resized 224x224 images, whose statistics may differ
# slightly. Confirm this is acceptable.
channel_axes = tuple(range(train_xs.ndim - 1))
mean = train_xs.mean(axis=channel_axes)
stddev = train_xs.std(axis=channel_axes)

train_gen = data_gen(train_xs, train_ys, IMAGE_SHAPE, mean, stddev,
                     NUM_CLASSES, TRAIN_BATCH_SIZE)
val_gen = data_gen(val_xs, val_ys, IMAGE_SHAPE, mean, stddev,
                   NUM_CLASSES, VAL_BATCH_SIZE)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment