-
-
Save takatakamanbou/8b723dc2c65f6f18e58b to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import struct | |
import numpy as np | |
class MNIST:
    """Accessor for one of the two MNIST data sets stored as IDX files.

    The constructor only records which pair of files to use; nothing is
    read from disk until getLabel() or getImage() is called.  The file
    names carry no directory part, so they are resolved relative to the
    current working directory.
    """

    def __init__( self, LT ):
        """Select the data set.

        LT -- 'L' selects the training ("learning") files; any other
              value selects the 10k test files.
        """
        if LT == 'L':
            self.fnLabel = 'train-labels-idx1-ubyte'
            self.fnImage = 'train-images-idx3-ubyte'
        else:
            self.fnLabel = 't10k-labels-idx1-ubyte'
            self.fnImage = 't10k-images-idx3-ubyte'

    def getLabel( self ):
        """Return the labels read from the selected label file."""
        return _readLabel( self.fnLabel )

    def getImage( self ):
        """Return the images read from the selected image file."""
        return _readImage( self.fnImage )
##### reading the label file
#
def _readLabel( fnLabel ):
    """Read an IDX1 label file and return its labels as a 1-D int array.

    fnLabel -- path of the label file.

    The file starts with two big-endian 4-byte integers (magic number
    2049 and the number of items), followed by one unsigned byte per
    label.

    Raises AssertionError if the magic number is not 2049, and
    struct.error if the header is shorter than 8 bytes or the number of
    bytes after the header differs from the item count in the header.
    """
    # 'rb': the content is raw bytes.  Text mode would try to decode it
    # (Python 3) or translate line endings (Windows).  The with-block
    # closes the file even when parsing fails.
    with open( fnLabel, 'rb' ) as f:

        ### header (two 4B integers, magic number(2049) & number of items)
        #
        header = f.read( 8 )
        mn, num = struct.unpack( '>2i', header )  # MSB first (bigendian)
        # Raised explicitly (not via `assert`) so the check is not
        # stripped when Python runs with -O.
        if mn != 2049:
            raise AssertionError( 'bad magic number %d in label file (expected 2049)' % mn )

        ### labels (unsigned byte)
        #
        label = np.array( struct.unpack( '>%dB' % num, f.read() ), dtype = int )

    return label
##### reading the image file
#
def _readImage( fnImage ):
    """Read an IDX3 image file and return its pixels as a float array.

    fnImage -- path of the image file.

    The file starts with four big-endian 4-byte integers (magic number
    2051, number of images, rows, columns), followed by one unsigned
    byte per pixel.  Bytes beyond the announced pixel data are ignored.

    Returns an array of shape (num, nrow, ncol) with numpy's default
    float dtype; the pixel values are the raw byte values.

    Raises AssertionError if the magic number is not 2051, and
    struct.error if the header or the pixel data is truncated.
    """
    # 'rb': the content is raw bytes.  Text mode would try to decode it
    # (Python 3) or translate line endings (Windows).  The with-block
    # closes the file even when parsing fails.
    with open( fnImage, 'rb' ) as f:

        ### header (four 4B integers, magic number(2051), #images, #rows, and #cols
        #
        header = f.read( 16 )
        mn, num, nrow, ncol = struct.unpack( '>4i', header )  # MSB first (bigendian)
        # Raised explicitly (not via `assert`) so the check is not
        # stripped when Python runs with -O.
        if mn != 2051:
            raise AssertionError( 'bad magic number %d in image file (expected 2051)' % mn )

        ### pixels (unsigned byte)
        #
        nbyte = num * nrow * ncol
        buf = f.read( nbyte )

    # A short read means a truncated file; report it with the same
    # exception type that struct.unpack raises for short input.
    if len( buf ) != nbyte:
        raise struct.error( 'image file holds %d pixel bytes, header announces %d' % ( len( buf ), nbyte ) )

    # Decode all images in one pass instead of unpacking them one by one;
    # astype() yields a fresh, writable float array.
    pixel = np.frombuffer( buf, dtype = np.uint8 ).reshape( ( num, nrow, ncol ) ).astype( float )

    return pixel
if __name__ == '__main__':

    # Smoke test: load both data sets and print the array shapes.
    # print() is called with a single pre-formatted string so the output
    # is the same under Python 2 and Python 3.
    for title, LT in ( ( '# MNIST training data', 'L' ),
                       ( '# MNIST test data', 'T' ) ):
        print( title )
        mnist = MNIST( LT )
        lab = mnist.getLabel()
        dat = mnist.getImage()
        print( '%s %s' % ( lab.shape, dat.shape ) )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import cv2 | |
import numpy as np | |
import mnist0117 as mnist | |
# Tile the first nx*ny training images into one grey-background mosaic
# and write it to 'hoge.png'.
nx = 10     # number of tiles per row
ny = 10     # number of tiles per column
gap = 4     # margin in pixels between tiles and around the border

mn = mnist.MNIST( 'L' )
dat = mn.getImage()[:nx*ny]
nrow, ncol = dat.shape[1:]

# Canvas size: every tile is preceded by one gap, plus a closing gap.
width = nx * ( ncol + gap ) + gap
height = ny * ( nrow + gap ) + gap

# 8-bit canvas filled with mid-grey (128).  An explicit uint8 array is
# used so that cv2.imwrite receives a depth it supports directly instead
# of relying on an implicit conversion.
# NOTE(review): assumes getImage() yields values in 0..255 -- confirm
# against mnist0117.
img = np.full( ( height, width ), 128, dtype = np.uint8 )

for iy in range( ny ):
    lty = iy * ( nrow + gap ) + gap         # top edge of this tile row
    for ix in range( nx ):
        ltx = ix * ( ncol + gap ) + gap     # left edge of this tile
        img[lty:lty+nrow, ltx:ltx+ncol] = dat[iy*nx+ix]

cv2.imwrite( 'hoge.png', img )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment