Skip to content

Instantly share code, notes, and snippets.

@agastidukare
Created April 1, 2020 01:32
Show Gist options
  • Save agastidukare/4e2d6c6a85c26ca99ef1522b9f04fc27 to your computer and use it in GitHub Desktop.
Save agastidukare/4e2d6c6a85c26ca99ef1522b9f04fc27 to your computer and use it in GitHub Desktop.
# Directory in which downloaded dataset files are stored and looked up.
_DATA = "/tmp/"
def _download(url, filename):
    """Download a url to a file in the JAX data temp directory.

    The transfer is skipped when a file named ``filename`` already exists in
    ``_DATA``. A message is printed only when a download actually happened.

    Args:
      url: Address to fetch.
      filename: Name of the destination file inside ``_DATA``.
    """
    # exist_ok avoids the check-then-create race when several processes
    # start at the same time.
    os.makedirs(_DATA, exist_ok=True)
    out_file = path.join(_DATA, filename)
    if path.isfile(out_file):
        return

    # Fetch into a per-process temporary name and rename on success, so an
    # interrupted transfer never leaves a truncated file at the final path
    # that later runs would mistake for a finished download.
    partial_file = "{}.part.{}".format(out_file, os.getpid())
    try:
        urllib.request.urlretrieve(url, partial_file)
    except BaseException:
        # Remove the fragment (if any) and let the original error propagate.
        if path.exists(partial_file):
            os.remove(partial_file)
        raise
    os.replace(partial_file, out_file)
    print("downloaded {} to {}".format(url, _DATA))
def _partial_flatten(x):
    """Collapse every axis after the first into one trailing axis.

    Args:
      x: Array exposing ``.shape``; its first axis is kept as-is.

    Returns:
      A two-dimensional array of shape ``(x.shape[0], -1)``.
    """
    leading = x.shape[0]
    flat_shape = (leading, -1)
    return numpy.reshape(x, flat_shape)
def _one_hot(x, k, dtype=numpy.float32):
    """Build a one-hot encoding of the labels in ``x`` with ``k`` columns.

    Args:
      x: One-dimensional array of integer class indices.
      k: Number of classes, i.e. the width of each encoded row.
      dtype: Element type of the returned array.

    Returns:
      An array of shape ``(len(x), k)`` whose entry ``[i, j]`` is 1 where
      ``x[i] == j`` and 0 elsewhere.
    """
    class_ids = numpy.arange(k)
    # Broadcasting a column of labels against the row of class ids yields the
    # boolean membership table in one comparison.
    column = x[:, None]
    membership = column == class_ids
    return numpy.array(membership, dtype)
def mnist_raw():
    """Download and parse the raw MNIST dataset.

    The four gzip-compressed IDX files are fetched into ``_DATA`` (unless
    already present there) and decoded.

    Returns:
      A tuple ``(train_images, train_labels, test_images, test_labels)`` of
      uint8 arrays. Image arrays have shape ``(num_data, rows, cols)`` as
      declared in each file's header; label arrays are one-dimensional.

    Raises:
      ValueError: If a file does not start with the expected IDX magic number,
        or its payload size disagrees with the counts in its header.
    """
    # CVDF mirror of http://yann.lecun.com/exdb/mnist/
    base_url = "https://storage.googleapis.com/cvdf-datasets/mnist/"

    # IDX magic numbers: 0x00000801 (unsigned bytes, 1 dimension) for labels
    # and 0x00000803 (unsigned bytes, 3 dimensions) for images.
    label_magic = 2049
    image_magic = 2051

    def parse_labels(filename):
        """Decode an IDX label file into a 1-D uint8 array."""
        with gzip.open(filename, "rb") as fh:
            magic, num_data = struct.unpack(">II", fh.read(8))
            payload = fh.read()
        if magic != label_magic:
            raise ValueError(
                "{}: bad label-file magic number {}".format(filename, magic))
        labels = numpy.frombuffer(payload, dtype=numpy.uint8)
        if labels.size != num_data:
            raise ValueError(
                "{}: header declares {} labels but {} were read".format(
                    filename, num_data, labels.size))
        # frombuffer yields a read-only view of the bytes; copy so callers
        # receive an independent, writable array.
        return labels.copy()

    def parse_images(filename):
        """Decode an IDX image file into a (num_data, rows, cols) uint8 array."""
        with gzip.open(filename, "rb") as fh:
            magic, num_data, rows, cols = struct.unpack(">IIII", fh.read(16))
            payload = fh.read()
        if magic != image_magic:
            raise ValueError(
                "{}: bad image-file magic number {}".format(filename, magic))
        pixels = numpy.frombuffer(payload, dtype=numpy.uint8)
        # reshape raises ValueError when the payload size does not match the
        # dimensions declared in the header.
        return pixels.reshape(num_data, rows, cols).copy()

    for filename in ["train-images-idx3-ubyte.gz", "train-labels-idx1-ubyte.gz",
                     "t10k-images-idx3-ubyte.gz", "t10k-labels-idx1-ubyte.gz"]:
        _download(base_url + filename, filename)

    train_images = parse_images(path.join(_DATA, "train-images-idx3-ubyte.gz"))
    train_labels = parse_labels(path.join(_DATA, "train-labels-idx1-ubyte.gz"))
    test_images = parse_images(path.join(_DATA, "t10k-images-idx3-ubyte.gz"))
    test_labels = parse_labels(path.join(_DATA, "t10k-labels-idx1-ubyte.gz"))

    return train_images, train_labels, test_images, test_labels
def mnist(create_outliers=False):
    """Load MNIST with images scaled to [0, 1] and labels one-hot encoded.

    Args:
      create_outliers: When true, the first 30000 training images are
        shuffled among themselves with a fixed seed while the labels stay in
        place, so those images no longer line up with their labels.

    Returns:
      A tuple ``(train_images, train_labels, test_images, test_labels)``.
      Images are flattened to two dimensions and divided by 255 as float32;
      labels are one-hot arrays with 10 columns.
    """
    raw_train_x, raw_train_y, raw_test_x, raw_test_y = mnist_raw()

    pixel_max = numpy.float32(255.)
    train_images = _partial_flatten(raw_train_x) / pixel_max
    test_images = _partial_flatten(raw_test_x) / pixel_max

    train_labels = _one_hot(raw_train_y, 10)
    test_labels = _one_hot(raw_test_y, 10)

    if create_outliers:
        num_outliers = 30000
        rng = numpy.random.RandomState(0)
        order = rng.permutation(num_outliers)
        # Fancy indexing copies the selected rows first, so writing them back
        # into the same leading slice is safe.
        train_images[:num_outliers] = train_images[order]

    return train_images, train_labels, test_images, test_labels
def shape_as_image(images, labels, dummy_dim=False):
    """Reshape images to 28x28 single-channel form; labels pass through.

    Args:
      images: Array whose elements can be regrouped into 28*28*1 blocks.
      labels: Returned unchanged alongside the reshaped images.
      dummy_dim: When true, an extra axis of size 1 is inserted right after
        the leading axis.

    Returns:
      A pair ``(reshaped_images, labels)`` where the images have shape
      ``(-1, 28, 28, 1)``, or ``(-1, 1, 28, 28, 1)`` with ``dummy_dim``.
    """
    if dummy_dim:
        target_shape = (-1, 1, 28, 28, 1)
    else:
        target_shape = (-1, 28, 28, 1)
    reshaped = np.reshape(images, target_shape)
    return reshaped, labels
# Load MNIST when this code is executed: flattened, unit-scaled images and
# one-hot labels, with the outlier shuffling disabled.
train_images, train_labels, test_images, test_labels = mnist(create_outliers=False)
# Number of training examples (size of the leading axis of the image array).
num_train = train_images.shape[0]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment