-
-
Save ischlag/41d15424e7989b936c1609b53edd1390 to your computer and use it in GitHub Desktop.
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
import gzip | |
import os | |
import sys | |
import time | |
from six.moves import urllib | |
from six.moves import xrange # pylint: disable=redefined-builtin | |
from scipy.misc import imsave | |
import tensorflow as tf | |
import numpy as np | |
import csv | |
# Base URL that the MNIST archive names are appended to.
# NOTE(review): confirm this host still serves the files; swap in a mirror if
# downloads start failing.
SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'
# Local directory that downloaded archives are cached in.
WORK_DIRECTORY = 'data'
# Side length, in pixels, used when reshaping each image.
IMAGE_SIZE = 28
NUM_CHANNELS = 1
PIXEL_DEPTH = 255
NUM_LABELS = 10
def maybe_download(filename):
    """Download the data from Yann's website, unless it's already here.

    Args:
        filename: Archive name; appended to SOURCE_URL to form the download
            URL and joined onto WORK_DIRECTORY to form the local path.

    Returns:
        Path of the local copy of the archive.
    """
    if not tf.gfile.Exists(WORK_DIRECTORY):
        tf.gfile.MakeDirs(WORK_DIRECTORY)
    filepath = os.path.join(WORK_DIRECTORY, filename)
    if not tf.gfile.Exists(filepath):
        filepath, _ = urllib.request.urlretrieve(SOURCE_URL + filename, filepath)
        # urlretrieve has just written a local file, so ask the OS for its
        # size. The GFile accessor was spelled `Size()` here, which fails on
        # TensorFlow releases that name it `size()`.
        size = os.path.getsize(filepath)
        print('Successfully downloaded', filename, size, 'bytes.')
    return filepath
def extract_data(filename, num_images):
    """Extract the images into a 4D array [image index, y, x, channels].

    Pixel values are converted to float32 but are NOT rescaled: they keep
    the numeric values of the unsigned bytes stored in the archive.

    Args:
        filename: Path of the gzip-compressed image archive.
        num_images: Number of images to read from the archive.

    Returns:
        A float32 numpy array of shape
        (num_images, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS).
    """
    print('Extracting', filename)
    with gzip.open(filename, 'rb') as bytestream:
        # Skip the 16 bytes that precede the pixel data.
        bytestream.read(16)
        buf = bytestream.read(
            IMAGE_SIZE * IMAGE_SIZE * num_images * NUM_CHANNELS)
    data = np.frombuffer(buf, dtype=np.uint8).astype(np.float32)
    # reshape raises ValueError if the archive held fewer bytes than requested.
    return data.reshape(num_images, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS)
def extract_labels(filename, num_images):
    """Extract the labels into a vector of int64 label IDs.

    Args:
        filename: Path of the gzip-compressed label archive.
        num_images: Number of labels (one byte each) to read.

    Returns:
        A 1-D int64 numpy array holding exactly ``num_images`` labels.

    Raises:
        ValueError: If the archive holds fewer than ``num_images`` labels.
    """
    print('Extracting', filename)
    with gzip.open(filename, 'rb') as bytestream:
        # Skip the 8 bytes that precede the label data.
        bytestream.read(8)
        buf = bytestream.read(num_images)
    # A short read would otherwise yield a label vector that silently
    # disagrees in length with the image array it is paired with.
    if len(buf) != num_images:
        raise ValueError(
            'Expected %d labels in %s but found only %d'
            % (num_images, filename, len(buf)))
    return np.frombuffer(buf, dtype=np.uint8).astype(np.int64)
def _open_csv_for_writing(path):
    """Open ``path`` in the mode the running interpreter's csv module needs."""
    if sys.version_info[0] < 3:
        # The Python 2 csv module writes byte strings.
        return open(path, 'wb')
    # The Python 3 csv module writes text; opening with 'wb' raises
    # "a bytes-like object is required, not 'str'". newline='' lets the
    # writer control line endings itself.
    return open(path, 'w', newline='')


def _export_split(split, images, labels):
    """Write one split as mnist/<split>-images/<i>.jpg plus a labels CSV."""
    image_dir = "mnist/" + split + "-images"
    if not os.path.isdir(image_dir):
        os.makedirs(image_dir)
    with _open_csv_for_writing("mnist/" + split + "-labels.csv") as csv_file:
        writer = csv.writer(csv_file, delimiter=',', quotechar='"')
        for i, (image, label) in enumerate(zip(images, labels)):
            # CSV rows hold the image path relative to the mnist/ directory.
            relative_path = split + "-images/" + str(i) + ".jpg"
            # NOTE(review): scipy.misc.imsave is absent from newer SciPy
            # releases; pin SciPy or port to another image writer if the
            # import at the top of the file fails.
            imsave("mnist/" + relative_path, image[:, :, 0])
            writer.writerow([relative_path, label])


train_data_filename = maybe_download('train-images-idx3-ubyte.gz')
train_labels_filename = maybe_download('train-labels-idx1-ubyte.gz')
test_data_filename = maybe_download('t10k-images-idx3-ubyte.gz')
test_labels_filename = maybe_download('t10k-labels-idx1-ubyte.gz')

# Extract it into np arrays.
train_data = extract_data(train_data_filename, 60000)
train_labels = extract_labels(train_labels_filename, 60000)
test_data = extract_data(test_data_filename, 10000)
test_labels = extract_labels(test_labels_filename, 10000)

# Process train data, then repeat for test data.
_export_split("train", train_data, train_labels)
_export_split("test", test_data, test_labels)
When I run line 83 [writer.writerow(["train-images/" + str(i) + ".jpg", train_labels[i]])], I get "a bytes-like object is required, not 'str'". Please help! Thanks!
@temiwale88 you just need to change
with open("mnist/train-labels.csv", 'wb') as csvFile:
to
with open("mnist/train-labels.csv", 'w') as csvFile:
and do the same for the other line.
It was expecting an array of bytes to write since you opened the file with 'wb' instead of 'w' ;)
This did the trick for me ;)
Does anyone know how to adapt this for EMNIST? I think the number of classes changes from 10 to 47; does anything else need to change? I ask because the shapes don't match.
At line 83,
writer.writerow(["train-images/" + str(i) + ".jpg", train_labels[i]])
I get this error:
PermissionError: [Errno 13] Permission denied
Please help! Thanks in advance.
Line 32 must be
size = f.size()