Skip to content

Instantly share code, notes, and snippets.

@lelayf
Created August 30, 2014 20:16
Show Gist options
  • Save lelayf/6bf4901b036ffc57e0c7 to your computer and use it in GitHub Desktop.
Save lelayf/6bf4901b036ffc57e0c7 to your computer and use it in GitHub Desktop.
Predict MNIST digits with a Caffe-trained model — Step 1: prepare input data
import os, struct
import numpy as np
from array import array as pyarray
from numpy import append, array, int8, uint8, zeros
def read(digits, dataset = "training", path = "."):
    """
    Load the MNIST samples whose label is in `digits` into numpy arrays.

    Adapted from: http://abel.ee.ucla.edu/cvxopt/_downloads/mnist.py

    Parameters
    ----------
    digits : container of int
        Labels to keep, e.g. ``[2]`` or ``range(10)``.
    dataset : str
        ``"training"`` (reads ``train-*-ubyte``) or ``"testing"``
        (reads ``t10k-*-ubyte``).
    path : str
        Directory containing the uncompressed MNIST IDX files.

    Returns
    -------
    images : numpy.ndarray, shape (N, rows, cols), dtype uint8
    labels : numpy.ndarray, shape (N, 1), dtype int8
        N is the number of samples whose label is in `digits`; samples keep
        their order from the files.

    Raises
    ------
    ValueError
        If `dataset` is not recognised, a file does not carry the expected IDX
        magic number, or the label and image files disagree on sample count.
    """
    # Compare with ==: `is` on strings only works when both happen to be the
    # same interned object, so a dataset name built at runtime was rejected.
    if dataset == "training":
        prefix = "train"
    elif dataset == "testing":
        prefix = "t10k"
    else:
        raise ValueError("dataset must be 'testing' or 'training'")
    fname_img = os.path.join(path, prefix + '-images-idx3-ubyte')
    fname_lbl = os.path.join(path, prefix + '-labels-idx1-ubyte')

    # Label file: big-endian header (magic, count) followed by one byte per label.
    with open(fname_lbl, 'rb') as flbl:
        magic_nr, n_labels = struct.unpack(">II", flbl.read(8))
        if magic_nr != 2049:
            raise ValueError("{0}: bad magic number {1} (expected 2049)".format(fname_lbl, magic_nr))
        lbl = np.frombuffer(flbl.read(), dtype=int8, count=n_labels)

    # Image file: big-endian header (magic, count, rows, cols) followed by
    # count*rows*cols unsigned pixel bytes in row-major order.
    with open(fname_img, 'rb') as fimg:
        magic_nr, size, rows, cols = struct.unpack(">IIII", fimg.read(16))
        if magic_nr != 2051:
            raise ValueError("{0}: bad magic number {1} (expected 2051)".format(fname_img, magic_nr))
        img = np.frombuffer(fimg.read(), dtype=uint8, count=size * rows * cols).reshape(size, rows, cols)

    if n_labels != size:
        raise ValueError("label file has {0} samples but image file has {1}".format(n_labels, size))

    wanted = set(digits)
    # tolist() yields plain Python ints, keeping the membership test cheap.
    ind = np.array([k for k, d in enumerate(lbl.tolist()) if d in wanted], dtype=np.intp)
    # Fancy indexing returns fresh, writable copies of the selected samples
    # (the frombuffer views above are read-only).
    images = img[ind]
    labels = lbl[ind].reshape(-1, 1)
    return images, labels
images, labels = read([2], 'training', '/home/ubuntu/Repositories/caffe/data/mnist')
# Take (up to) the first 100 selected images and append a trailing channel
# axis, giving an (N, rows, cols, 1) float64 array. Deriving the shape from
# `images` instead of hard-coding (100, 28, 28, 1) avoids a broadcast error
# when fewer than 100 images match or the image size differs.
# The 1/256 factor maps pixel values 0..255 into [0, 1); presumably it must
# match the input scaling used when the Caffe model was trained -- TODO confirm.
outputs = images[:100, :, :, np.newaxis].astype(float) * (1.0 / 256)
print(labels[:100])
np.save('mnist-predict-100-twos.npy', outputs)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment