Created
April 1, 2020 01:32
-
-
Save agastidukare/4e2d6c6a85c26ca99ef1522b9f04fc27 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Directory that downloaded dataset archives are written to and read back from.
_DATA = "/tmp/"
def _download(url, filename):
    """Download a url to a file in the JAX data temp directory.

    The transfer is skipped when `filename` already exists inside `_DATA`.

    Args:
      url: address of the resource to fetch.
      filename: name of the destination file, relative to `_DATA`.
    """
    # exist_ok avoids the check-then-create race of testing for the directory
    # first and creating it afterwards.
    os.makedirs(_DATA, exist_ok=True)
    out_file = path.join(_DATA, filename)
    if path.isfile(out_file):
        return

    # Fetch into a sibling temporary file and rename it into place only once
    # the transfer has finished. Otherwise an interrupted download would leave
    # a truncated file behind that later calls would accept as complete.
    partial_file = out_file + ".part"
    try:
        urllib.request.urlretrieve(url, partial_file)
        os.replace(partial_file, out_file)
    finally:
        # After a successful rename the temporary file is gone; this only
        # cleans up after a failed or interrupted transfer.
        if path.exists(partial_file):
            os.remove(partial_file)
    print("downloaded {} to {}".format(url, _DATA))
def _partial_flatten(x): | |
"""Flatten all but the first dimension of an ndarray.""" | |
return numpy.reshape(x, (x.shape[0], -1)) | |
def _one_hot(x, k, dtype=numpy.float32): | |
"""Create a one-hot encoding of x of size k.""" | |
return numpy.array(x[:, None] == numpy.arange(k), dtype) | |
def mnist_raw():
    """Fetch the four MNIST archives and decode them into uint8 arrays.

    Returns:
      A tuple (train_images, train_labels, test_images, test_labels). Image
      arrays are reshaped to (count, rows, cols) using the sizes read from each
      file's header; label arrays are one-dimensional.
    """
    # CVDF mirror of http://yann.lecun.com/exdb/mnist/
    base_url = "https://storage.googleapis.com/cvdf-datasets/mnist/"

    def parse_labels(filename):
        # An 8-byte big-endian header (two unsigned 32-bit fields, not used
        # here) precedes the payload of one byte per label.
        with gzip.open(filename, "rb") as stream:
            struct.unpack(">II", stream.read(8))
            payload = array.array("B", stream.read())
        return numpy.array(payload, dtype=numpy.uint8)

    def parse_images(filename):
        # A 16-byte big-endian header carries four unsigned 32-bit fields; the
        # last three give the image count, rows and columns.
        with gzip.open(filename, "rb") as stream:
            header = struct.unpack(">IIII", stream.read(16))
            payload = array.array("B", stream.read())
        num_data, rows, cols = header[1:]
        pixels = numpy.array(payload, dtype=numpy.uint8)
        return pixels.reshape(num_data, rows, cols)

    archives = (
        "train-images-idx3-ubyte.gz",
        "train-labels-idx1-ubyte.gz",
        "t10k-images-idx3-ubyte.gz",
        "t10k-labels-idx1-ubyte.gz",
    )
    # Make sure every archive is present locally before any parsing starts.
    for archive in archives:
        _download(base_url + archive, archive)

    train_x_path, train_y_path, test_x_path, test_y_path = [
        path.join(_DATA, archive) for archive in archives
    ]
    return (
        parse_images(train_x_path),
        parse_labels(train_y_path),
        parse_images(test_x_path),
        parse_labels(test_y_path),
    )
def mnist(create_outliers=False, num_outliers=30000):
    """Download, parse and process MNIST data to unit scale and one-hot labels.

    Images are flattened to one row per example and divided by 255; labels are
    one-hot encoded over 10 classes.

    Args:
      create_outliers: when true, the first `num_outliers` training images are
        reordered among themselves with a fixed-seed permutation while the
        training labels are left untouched.
      num_outliers: how many leading training images take part in that
        reordering. Must not exceed the number of training images. Ignored
        unless `create_outliers` is true.

    Returns:
      A tuple (train_images, train_labels, test_images, test_labels).
    """
    train_images, train_labels, test_images, test_labels = mnist_raw()

    train_images = _partial_flatten(train_images) / numpy.float32(255.)
    test_images = _partial_flatten(test_images) / numpy.float32(255.)
    train_labels = _one_hot(train_labels, 10)
    test_labels = _one_hot(test_labels, 10)

    if create_outliers:
        # Seed 0 keeps the permutation identical across runs.
        perm = numpy.random.RandomState(0).permutation(num_outliers)
        # Indexing with `perm` produces a copy, so assigning it back over the
        # same slice is safe.
        train_images[:num_outliers] = train_images[:num_outliers][perm]

    return train_images, train_labels, test_images, test_labels
def shape_as_image(images, labels, dummy_dim=False):
    """Reshape image data to 28x28 single-channel layout; labels pass through.

    Args:
      images: array whose elements can be regrouped into 28 * 28 * 1 blocks.
      labels: returned unchanged alongside the reshaped images.
      dummy_dim: when true, a size-1 axis is inserted right after the leading
        axis, giving (-1, 1, 28, 28, 1) instead of (-1, 28, 28, 1).

    Returns:
      A tuple (reshaped_images, labels).
    """
    if dummy_dim:
        target_shape = (-1, 1, 28, 28, 1)
    else:
        target_shape = (-1, 28, 28, 1)
    reshaped = np.reshape(images, target_shape)
    return reshaped, labels
# Load the processed MNIST splits at module level, without outlier shuffling.
train_images, train_labels, test_images, test_labels = mnist(create_outliers=False)
# Number of training examples (size of the leading axis of train_images).
num_train = train_images.shape[0]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment