Skip to content

Instantly share code, notes, and snippets.

@takatakamanbou takatakamanbou/mnist0117.py Secret
Last active Aug 29, 2015

Embed
What would you like to do?
import os
import struct

import numpy as np
class MNIST:
    """Locate and load one split of the MNIST dataset stored as IDX files.

    Parameters
    ----------
    LT : str
        ``'L'`` selects the training files (``train-*-idx?-ubyte``); any
        other value selects the test files (``t10k-*-idx?-ubyte``).
    dirname : str, optional
        Directory holding the uncompressed IDX files.  The default ``''``
        means the current working directory, which matches the historical
        behaviour of this class.
    """

    def __init__( self, LT, dirname = '' ):
        if LT == 'L':
            prefix = 'train'
        else:
            # Anything that is not 'L' falls back to the test split.
            prefix = 't10k'
        # os.path.join('', name) == name, so the default keeps bare filenames.
        self.fnLabel = os.path.join( dirname, prefix + '-labels-idx1-ubyte' )
        self.fnImage = os.path.join( dirname, prefix + '-images-idx3-ubyte' )

    def getLabel( self ):
        """Read and return the labels from ``self.fnLabel`` (see ``_readLabel``)."""
        return _readLabel( self.fnLabel )

    def getImage( self ):
        """Read and return the images from ``self.fnImage`` (see ``_readImage``)."""
        return _readImage( self.fnImage )
##### reading the label file
#
def _readLabel( fnLabel ):
f = open( fnLabel, 'r' )
### header (two 4B integers, magic number(2049) & number of items)
#
header = f.read( 8 )
mn, num = struct.unpack( '>2i', header ) # MSB first (bigendian)
assert mn == 2049
#print mn, num
### labels (unsigned byte)
#
label = np.array( struct.unpack( '>%dB' % num, f.read() ), dtype = int )
f.close()
return label
##### reading the image file
#
def _readImage( fnImage ):
f = open( fnImage, 'r' )
### header (four 4B integers, magic number(2051), #images, #rows, and #cols
#
header = f.read( 16 )
mn, num, nrow, ncol = struct.unpack( '>4i', header ) # MSB first (bigendian)
assert mn == 2051
#print mn, num, nrow, ncol
### pixels (unsigned byte)
#
pixel = np.empty( ( num, nrow, ncol ) )
npixel = nrow * ncol
for i in range( num ):
buf = struct.unpack( '>%dB' % npixel, f.read( npixel ) )
pixel[i, :, :] = np.asarray( buf ).reshape( ( nrow, ncol ) )
f.close()
return pixel
if __name__ == '__main__':
    # Smoke test: load both splits and report the array shapes.
    # Single-argument print(...) behaves identically on Python 2 and 3,
    # unlike the Python-2-only print statement.
    for title, lt in ( ( 'training', 'L' ), ( 'test', 'T' ) ):
        print( '# MNIST %s data' % title )
        mnist = MNIST( lt )
        lab = mnist.getLabel()
        dat = mnist.getImage()
        print( '%s %s' % ( lab.shape, dat.shape ) )
import cv2
import numpy as np
import mnist0117 as mnist
# Tile the first nx*ny training digits into one grid image, separated and
# surrounded by a `gap`-pixel mid-grey (128) border, and save it as a PNG.
nx = 10   # digits per row
ny = 10   # digits per column
gap = 4   # border width in pixels

mn = mnist.MNIST( 'L' )
dat = mn.getImage()[:nx*ny]
nrow, ncol = dat.shape[1:]

# Each cell is (digit + gap) wide/high, plus one trailing gap on the far edge.
width = nx * ( ncol + gap ) + gap
height = ny * ( nrow + gap ) + gap
# Explicit 8-bit canvas: the pixel type PNG stores directly, so cv2.imwrite
# is not handed a default-int (int64) array that it must convert itself.
img = np.full( ( height, width ), 128, dtype = np.uint8 )
for iy in range( ny ):
    lty = iy * ( nrow + gap ) + gap      # top edge of this row of cells
    for ix in range( nx ):
        ltx = ix * ( ncol + gap ) + gap  # left edge of this cell
        img[lty:lty+nrow, ltx:ltx+ncol] = dat[iy*nx+ix]
cv2.imwrite( 'hoge.png', img )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.