Skip to content

Instantly share code, notes, and snippets.

@takatakamanbou takatakamanbou/mnist.py Secret
Last active Oct 21, 2017

Embed
What would you like to do?
from __future__ import print_function
import struct
import os
import numpy as np
class MNIST(object):
    """Accessor for the four raw MNIST files stored in one directory.

    Derives explicitly from ``object`` so the class is new-style on
    Python 2 as well as Python 3.

    Attributes:
        fnLabel: dict mapping 'L' (training set) / 'T' (test set) to the
            path of the corresponding label file.
        fnImage: dict mapping 'L' / 'T' to the path of the corresponding
            image file.
        nrow, ncol: image height and width, fixed to 28.
        nclass: number of label classes, fixed to 10.
    """

    def __init__(self, pathMNIST='.'):
        """Record the data file paths under *pathMNIST*; opens no file."""
        self.fnLabel = {
            'L': os.path.join(pathMNIST, 'train-labels-idx1-ubyte'),
            'T': os.path.join(pathMNIST, 't10k-labels-idx1-ubyte'),
        }
        self.fnImage = {
            'L': os.path.join(pathMNIST, 'train-images-idx3-ubyte'),
            'T': os.path.join(pathMNIST, 't10k-images-idx3-ubyte'),
        }
        self.nrow = 28
        self.ncol = 28
        self.nclass = 10

    def getLabel(self, LT):
        """Return the labels of the set selected by *LT* ('L' or 'T').

        Raises:
            KeyError: if *LT* is neither 'L' nor 'T'.
        """
        return _readLabel(self.fnLabel[LT])

    def getImage(self, LT):
        """Return the images of the set selected by *LT* ('L' or 'T').

        Raises:
            KeyError: if *LT* is neither 'L' nor 'T'.
        """
        return _readImage(self.fnImage[LT])
##### reading the label file
#
def _readLabel(fnLabel):
    """Read an MNIST label file into a 1-D integer array.

    The file starts with an 8-byte header made of two big-endian 4-byte
    integers (the magic number 2049 and the number of items), followed by
    one unsigned byte per label.

    Args:
        fnLabel: path of the label file.

    Returns:
        numpy array of shape (num,) and dtype int holding the labels.

    Raises:
        AssertionError: if the magic number is not 2049.
        struct.error: if the header is shorter than 8 bytes, or the number
            of label bytes differs from the count given in the header.
    """
    # 'with' closes the file even when one of the checks below raises.
    with open(fnLabel, 'rb') as f:
        mn, num = struct.unpack('>2i', f.read(8))  # MSB first (big-endian)
        if mn != 2049:
            # Raised explicitly so the check still runs under ``python -O``;
            # the exception type is the one the former ``assert`` produced.
            raise AssertionError(
                'bad magic number %d in label file (expected 2049)' % mn)
        body = f.read()

    # Same strictness as unpacking '>%dB' % num: the payload must hold
    # exactly num bytes.
    if len(body) != num:
        raise struct.error(
            'label file holds %d label bytes, header announces %d'
            % (len(body), num))

    # frombuffer yields a read-only uint8 view; astype makes a writable
    # int copy.
    return np.frombuffer(body, dtype=np.uint8).astype(int)
##### reading the image file
#
def _readImage(fnImage):
    """Read an MNIST image file into a 2-D float array.

    The file starts with a 16-byte header made of four big-endian 4-byte
    integers (the magic number 2051, the number of images, the number of
    rows and the number of columns), followed by one unsigned byte per
    pixel.

    Args:
        fnImage: path of the image file.

    Returns:
        numpy array of shape (num, nrow * ncol) with numpy's default float
        dtype; each row is one flattened image holding the raw byte values.

    Raises:
        AssertionError: if the magic number is not 2051.
        struct.error: if the header is shorter than 16 bytes, or the file
            holds fewer pixel bytes than the header announces.
    """
    # 'with' closes the file even when one of the checks below raises.
    with open(fnImage, 'rb') as f:
        # MSB first (big-endian)
        mn, num, nrow, ncol = struct.unpack('>4i', f.read(16))
        if mn != 2051:
            # Raised explicitly so the check still runs under ``python -O``;
            # the exception type is the one the former ``assert`` produced.
            raise AssertionError(
                'bad magic number %d in image file (expected 2051)' % mn)
        npixel = ncol * nrow
        nbyte = num * npixel
        # Read only the announced payload; trailing bytes are ignored.
        raw = f.read(nbyte)

    if len(raw) != nbyte:
        raise struct.error(
            'image file holds %d pixel bytes, header announces %d'
            % (len(raw), nbyte))

    # Decode every image in one vectorised step instead of one
    # struct.unpack call per image; astype returns a writable float copy.
    return np.frombuffer(raw, dtype=np.uint8).reshape(num, npixel).astype(float)
if __name__ == '__main__':
    # Smoke test: load both data sets from ./mnist and print, for each one,
    # the image array's shape and dtype and the label array's shape.
    mnist = MNIST(pathMNIST='./mnist')
    for title, LT in (('# MNIST training data', 'L'),
                      ('# MNIST test data', 'T')):
        print(title)
        dat = mnist.getImage(LT)
        lab = mnist.getLabel(LT)
        print(dat.shape, dat.dtype, lab.shape)
### Program that visualizes the MNIST data ###
from __future__ import print_function
import cv2
import numpy as np
import mnist
nx, ny = 16, 8  # number of images placed horizontally / vertically
gap = 4         # spacing between neighbouring images, in pixels

# Take only the first nx * ny samples of the MNIST training data
mn = mnist.MNIST(pathMNIST='./mnist')
dat = mn.getImage('L')[:nx*ny]
lab = mn.getLabel('L')[:nx*ny]
nrow, ncol = mn.nrow, mn.ncol

# Width and height of the tiled image
width = nx * (ncol + gap) + gap
height = ny * (nrow + gap) + gap

# Build the canvas, filled with the value 128.  It is created as 8-bit
# unsigned, the depth cv2.imwrite documents for PNG output, instead of the
# platform-default integer type.
img = np.full((height, width), 128, dtype=np.uint8)
for iy in range(ny):
    lty = iy * (nrow + gap) + gap      # top edge of this row of tiles
    for ix in range(nx):
        ltx = ix * (ncol + gap) + gap  # left edge of this tile
        img[lty:lty+nrow, ltx:ltx+ncol] = dat[iy*nx+ix].reshape((nrow, ncol))

# Write the image
cv2.imwrite('hoge.png', img)

# Print the labels, arranged like the tiles
print(lab.reshape((ny, nx)))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.