Create a gist now

Instantly share code, notes, and snippets.

@takatakamanbou /cifar10.py Secret
Last active Mar 8, 2016

What would you like to do?
import numpy as np
import scipy as sp
import os
import cPickle
class CIFAR10(object):
    """Loader for the pickled CIFAR-10 batches stored in one directory.

    The directory must contain ``batches.meta``, the training batches
    ``data_batch_1`` ... ``data_batch_5`` and the test batch ``test_batch``.

    Attributes:
        path: the directory passed to the constructor.
        meta: the unpickled contents of ``batches.meta``.
        nclass: number of classes, ``len(meta['label_names'])``.
    """

    def __init__(self, dirname):
        """Read ``batches.meta`` from *dirname* and print a short summary."""
        self.path = dirname
        self.meta = self._unpickle(os.path.join(self.path, 'batches.meta'))
        self.nclass = len(self.meta['label_names'])
        # Single-argument print(...) behaves the same on Python 2 and 3.
        print('##### CIFAR-10 #####')
        print('# label_names = %s' % (self.meta['label_names'],))
        print('# num_vis = %s' % (self.meta['num_vis'],))

    @staticmethod
    def _unpickle(path):
        """Return the object pickled in the file at *path*.

        The file is opened in binary mode (a pickle is a byte stream) and is
        closed even if unpickling fails.  Works on Python 2 and Python 3.
        """
        try:
            import cPickle as pickle_mod  # Python 2
            load_kwargs = {}
        except ImportError:
            import pickle as pickle_mod  # Python 3
            # 'latin1' is needed to read NumPy arrays from pickles written by
            # Python 2; it is harmless for pickles written by Python 3.
            load_kwargs = {'encoding': 'latin1'}
        with open(path, 'rb') as f:
            return pickle_mod.load(f, **load_kwargs)

    def _loadBatch(self, fn):
        """Load the batch file *fn* (a name relative to ``self.path``).

        Returns:
            (data, labels): ``data`` is the array stored under the ``'data'``
            key (per the dataset layout: 10000 x 3072 uint8, where
            3072 = 3 x 32 x 32); ``labels`` is the ``'labels'`` list
            converted to a NumPy array of class indices in {0, ..., 9}.
        """
        d = self._unpickle(os.path.join(self.path, fn))
        return d['data'], np.array(d['labels'])

    def _loadL(self):
        """Load the five training batches and stack them in file order."""
        batches = [self._loadBatch('data_batch_%d' % i) for i in range(1, 6)]
        data = np.vstack([dat for dat, _ in batches])
        labels = np.hstack([lab for _, lab in batches])
        return data, labels

    def _loadT(self):
        """Load the test batch; returns (data, labels)."""
        return self._loadBatch('test_batch')

    ##### loading the data
    #
    def loadData(self, LT):
        """Load the training set (``LT == 'L'``) or, for any other value, the test set.

        Returns:
            X: float array of shape (N, 3, 32, 32).
            lab: integer class labels, shape (N,).
            t: bool one-hot targets, shape (N, nclass); ``t[n, k]`` is True
               iff ``lab[n] == k``.
        """
        dat, lab = self._loadL() if LT == 'L' else self._loadT()
        X = np.asarray(dat, dtype=float).reshape((-1, 3, 32, 32))
        # Broadcast (N, 1) against (nclass,) to build the one-hot matrix.
        t = lab[:, np.newaxis] == np.arange(self.nclass)
        return X, lab, t

    ##### generating the index of training & validation data
    #
    def genIndexLV(self, lab, seed=0, nperclass=1000):
        """Split sample indices into training and validation masks.

        For every class, up to ``nperclass`` randomly chosen samples go to
        the validation set; all the others go to the training set.

        Args:
            lab: 1-D array of class labels.
            seed: seed passed to ``np.random.seed`` (note: this reseeds
                NumPy's global random state).
            nperclass: number of validation samples per class (default 1000).

        Returns:
            (idxL, idxV): complementary bool masks over ``lab``.
        """
        np.random.seed(seed)
        idx = np.random.permutation(lab.shape[0])
        idxV = np.zeros(lab.shape[0], dtype=bool)
        for ik in range(self.nclass):
            # idx[...] maps positions in the shuffled order back to indices
            # into lab, so the chosen samples really belong to class ik.
            sel = idx[lab[idx] == ik][:nperclass]
            idxV[sel] = True
        # '~' is the boolean NOT; unary '-' on bool arrays raises on NumPy >= 1.13.
        idxL = ~idxV
        return idxL, idxV
if __name__ == "__main__":
import cv2
dirCIFAR10 = '../140823-pylearn2/data/cifar10/cifar-10-batches-py'
cifar10 = CIFAR10( dirCIFAR10 )
dataL, labelsL = cifar10._loadL()
w = h = 32
nclass = 10
nimg = 10
gap = 4
width = nimg * ( w + gap ) + gap
height = nclass * ( h + gap ) + gap
img = np.zeros( ( height, width, 3 ), dtype = int ) + 128
for iy in range( nclass ):
lty = iy * ( h + gap ) + gap
idx = np.where( labelsL == iy )[0]
for ix in range( nimg ):
ltx = ix * ( w + gap ) + gap
tmp = dataL[idx[ix], :].reshape( ( 3, h, w ) )
# BGR <= RGB
img[lty:lty+h, ltx:ltx+w, 0] = tmp[2, :, :]
img[lty:lty+h, ltx:ltx+w, 1] = tmp[1, :, :]
img[lty:lty+h, ltx:ltx+w, 2] = tmp[0, :, :]
cv2.imwrite( 'hoge.png', img )
import numpy as np
import scipy as sp
import cv2
import cifar10
# ---- load the CIFAR-10 training set ----
dirCIFAR10 = '../140823-pylearn2/data/cifar10/cifar-10-batches-py'
cifar = cifar10.CIFAR10(dirCIFAR10)
dat, lab, t = cifar.loadData('L')
N = dat.shape[0]

# Flatten every image into one row vector and divide the pixel values by 255.
X = dat.reshape((N, -1)) / 255
D = X.shape[1]

# Remove the per-dimension mean so each column of X is zero-centred.
xm = X.mean(axis=0)
X -= xm

# ---- eigen-decomposition of the covariance matrix ----
C = X.T.dot(X) / N
# C is symmetric positive semi-definite, so the left singular vectors from
# svd are its eigenvectors: column U[:, i] pairs with eigenvalue lam[i], and
# svd returns lam sorted in descending order.
U, lam, V = np.linalg.svd(C)
# Debug: print the eigen-spectrum relative to the largest eigenvalue.
# for i in range(D):
#     print(i, lam[i], lam[i] / lam[0])

# ---- ZCA whitening: Uzca = U diag(1 / sqrt(lam + eps)) U^T ----
# NOTE(review): with eps = 0, near-zero eigenvalues are inverted without any
# regularisation; a small positive eps is the usual safeguard — confirm that
# eps = 0 is intended.
eps = 0
sqlam = np.sqrt(lam + eps)
Uzca = (U / sqlam[np.newaxis, :]).dot(U.T)
Z = X.dot(Uzca.T)

# computing histograms (disabled: the snippet below is an inert string)
'''
nbin = 1000
Xh, Xbe = np.histogram( X, bins = nbin, normed = True )
Zh, Zbe = np.histogram( Z, bins = nbin, normed = True )
for i in range( nbin ):
print ( Xbe[i] + Xbe[i+1] )/2, Xh[i],
print ( Zbe[i] + Zbe[i+1] )/2, Zh[i]
'''
# ---- making output images ----
# One row per class, nimg whitened training images per row; each image is
# rescaled so that [-absmax, absmax] maps onto [1, 255] around mid-grey 128.
w = h = 32
nclass = cifar.nclass
nimg = 10
gap = 4
width = nimg * (w + gap) + gap
height = nclass * (h + gap) + gap
# cv2.imwrite documents 8-bit images, so build the mosaic as uint8.
img = np.zeros((height, width, 3), dtype=np.uint8)
for iy in range(nclass):
    lty = iy * (h + gap) + gap
    # Take at most nimg samples; a class with fewer simply leaves empty slots.
    idx = np.where(lab == iy)[0][:nimg]
    for ix, i in enumerate(idx):
        ltx = ix * (w + gap) + gap
        z = Z[i, :]
        absmax = np.max(np.abs(z))
        if absmax == 0:
            # An all-zero vector would give 0/0 = NaN; draw it as flat grey.
            absmax = 1.0
        tmp = z.reshape((3, h, w)) / absmax * 127 + 128
        # BGR <= RGB: reverse the channel axis and move it last for OpenCV.
        img[lty:lty + h, ltx:ltx + w, :] = tmp[::-1, :, :].transpose(1, 2, 0)
cv2.imwrite('hoge.png', img)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment