-
-
Save takatakamanbou/35d12aaa81b91d8d7e2c to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import print_function | |
import struct | |
import os | |
import numpy as np | |
class MNIST:
    """Locate the four MNIST IDX files in one directory and load them on demand.

    The split is selected with a one-letter key: 'L' for the training
    set (``train-*`` files) and 'T' for the test set (``t10k-*`` files).
    """

    def __init__(self, pathMNIST='.'):
        """Build the file paths; nothing is read from disk here.

        Args:
            pathMNIST: directory that holds the uncompressed MNIST files.
        """
        # File-name prefix used for each split key.
        prefixes = {'L': 'train', 'T': 't10k'}

        # split key -> path of the label file / image file
        self.fnLabel = {
            key: os.path.join(pathMNIST, prefix + '-labels-idx1-ubyte')
            for key, prefix in prefixes.items()
        }
        self.fnImage = {
            key: os.path.join(pathMNIST, prefix + '-images-idx3-ubyte')
            for key, prefix in prefixes.items()
        }

        # Fixed image geometry and number of classes.
        self.nrow, self.ncol = 28, 28
        self.nclass = 10

    def getLabel(self, LT):
        """Return the labels of split ``LT`` ('L' or 'T') via ``_readLabel``."""
        path = self.fnLabel[LT]
        return _readLabel(path)

    def getImage(self, LT):
        """Return the images of split ``LT`` ('L' or 'T') via ``_readImage``."""
        path = self.fnImage[LT]
        return _readImage(path)
##### reading the label file | |
# | |
def _readLabel( fnLabel ): | |
f = open( fnLabel, 'rb' ) | |
### header (two 4B integers, magic number(2049) & number of items) | |
# | |
header = f.read( 8 ) | |
mn, num = struct.unpack( '>2i', header ) # MSB first (bigendian) | |
assert mn == 2049 | |
#print mn, num | |
### labels (unsigned byte) | |
# | |
label = np.array( struct.unpack( '>%dB' % num, f.read() ), dtype = int ) | |
f.close() | |
return label | |
##### reading the image file | |
# | |
def _readImage( fnImage ): | |
f = open( fnImage, 'rb' ) | |
### header (four 4B integers, magic number(2051), #images, #rows, and #cols | |
# | |
header = f.read( 16 ) | |
mn, num, nrow, ncol = struct.unpack( '>4i', header ) # MSB first (bigendian) | |
assert mn == 2051 | |
#print mn, num, nrow, ncol | |
### pixels (unsigned byte) | |
# | |
npixel = ncol * nrow | |
#pixel = np.empty( ( num, npixel ), dtype = int ) | |
#pixel = np.empty( ( num, npixel ), dtype = np.int32 ) | |
pixel = np.empty( ( num, npixel ) ) | |
for i in range( num ): | |
buf = struct.unpack( '>%dB' % npixel, f.read( npixel ) ) | |
pixel[i, :] = np.asarray( buf ) | |
f.close() | |
return pixel | |
if __name__ == '__main__':
    # Smoke test: load both splits from ./mnist and print the shape and
    # dtype of the image array plus the shape of the label array.
    mnist = MNIST(pathMNIST='./mnist')

    for title, key in (('# MNIST training data', 'L'),
                       ('# MNIST test data', 'T')):
        print(title)
        dat = mnist.getImage(key)
        lab = mnist.getLabel(key)
        print(dat.shape, dat.dtype, lab.shape)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
### Program that visualizes the MNIST data ###
from __future__ import print_function | |
import cv2 | |
import numpy as np | |
import mnist | |
# Tile the first nx * ny MNIST training images on a gray canvas, save the
# result as 'hoge.png', and print the matching labels in the same layout.

nx, ny = 16, 8  # number of images per row / per column
gap = 4         # space between neighbouring images, in pixels

# Take only the first nx * ny samples of the MNIST training data
mn = mnist.MNIST(pathMNIST='./mnist')
dat = mn.getImage('L')[:nx*ny]
lab = mn.getLabel('L')[:nx*ny]
nrow, ncol = mn.nrow, mn.ncol

# Width and height of the tiled image
width = nx * (ncol + gap) + gap
height = ny * (nrow + gap) + gap

# Build the image on a mid-gray (128) background.
# The canvas is uint8 rather than the platform int: cv2.imwrite is
# documented to save 8-bit images (16-bit unsigned is also accepted for
# PNG), and the pixel values written below are byte values (0..255).
img = np.full((height, width), 128, dtype=np.uint8)
for iy in range(ny):
    lty = iy*(nrow + gap) + gap          # top edge of this row of tiles
    for ix in range(nx):
        ltx = ix*(ncol + gap) + gap      # left edge of this tile
        img[lty:lty+nrow, ltx:ltx+ncol] = dat[iy*nx+ix].reshape((nrow, ncol))

# Write the image
cv2.imwrite('hoge.png', img)

# Print the labels, arranged like the tiles
print(lab.reshape((ny, nx)))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment