Skip to content

Instantly share code, notes, and snippets.

@trtd56
Last active August 14, 2016 22:39
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save trtd56/d583928ba303a27646cbb7474f2f88a8 to your computer and use it in GitHub Desktop.
Save trtd56/d583928ba303a27646cbb7474f2f88a8 to your computer and use it in GitHub Desktop.
Chainerのtrainerを使ってCIFAR-10の分類に挑戦したかった ref: http://qiita.com/trtd56/items/6f1deddc5b9d1f2d6c06
def unpickle(file):
    """Load and return the single pickled object stored at path ``file``.

    On Python 3 the file is unpickled with ``encoding='latin-1'`` so that
    8-bit strings written by Python 2 (as in the CIFAR-10 "python version"
    batch files read below) decode without errors. On Python 2 the file is
    loaded as-is.

    Args:
        file: Path of the pickle file to read.

    Returns:
        The unpickled object.

    Security: ``pickle.load`` can execute arbitrary code embedded in the
    file -- only call this on trusted data.
    """
    # `with` guarantees the handle is closed even if unpickling raises.
    with open(file, 'rb') as fp:
        if sys.version_info.major == 2:
            return pickle.load(fp)
        # Python 3 and later: decode Python-2 byte strings as latin-1.
        # (Previously any version other than 2 or 3 left `data` unbound.)
        return pickle.load(fp, encoding='latin-1')
class Cifar10Model(chainer.Chain):
    """Small CNN producing 10 class scores per input image.

    Architecture: three stages of (3x3 conv -> relu -> 3x3 conv -> relu ->
    2x2 max-pool), every conv with 32 output channels and pad=1, followed by
    a 512-unit fully connected layer with relu + dropout and a final
    10-unit linear layer.
    """

    def __init__(self):
        super(Cifar10Model, self).__init__(
            # Use the link classes from `L`, like the Linear layers below;
            # `F.Convolution2D` was only a deprecated alias of
            # `L.Convolution2D` and was removed in later Chainer releases.
            conv1=L.Convolution2D(3, 32, 3, pad=1),
            conv2=L.Convolution2D(32, 32, 3, pad=1),
            conv3=L.Convolution2D(32, 32, 3, pad=1),
            conv4=L.Convolution2D(32, 32, 3, pad=1),
            conv5=L.Convolution2D(32, 32, 3, pad=1),
            conv6=L.Convolution2D(32, 32, 3, pad=1),
            # 512 = 32 channels * 4 * 4, which assumes a 32x32 input halved
            # by the three 2x2 poolings (32 -> 16 -> 8 -> 4).
            l1=L.Linear(512, 512),
            l2=L.Linear(512, 10),
        )

    def __call__(self, x, train=True):
        """Return unnormalised class scores (no softmax applied) for ``x``.

        Args:
            x: Input batch; presumably float32 of shape (N, 3, 32, 32) --
                only that spatial size matches the 512 inputs of ``l1``.
            train: Forwarded to ``F.dropout``; dropout is applied only when
                True. NOTE(review): callers that do not pass it (such as
                ``L.Classifier``) get the default True, so pass
                ``train=False`` explicitly for evaluation.
        """
        h = F.relu(self.conv1(x))
        h = F.max_pooling_2d(F.relu(self.conv2(h)), 2)
        h = F.relu(self.conv3(h))
        h = F.max_pooling_2d(F.relu(self.conv4(h)), 2)
        h = F.relu(self.conv5(h))
        h = F.max_pooling_2d(F.relu(self.conv6(h)), 2)
        h = F.dropout(F.relu(self.l1(h)), train=train)
        return self.l2(h)
# The training/test TupleDatasets and their SerialIterators are created
# further down, after the CIFAR-10 arrays have been loaded. The copies that
# used to stand here ran before `train` existed, referred to undefined
# `train_data`/`train_label`, and used the non-existent
# `chainer.tuple_dataset` path, so the script crashed with a NameError here.
# Directory of the extracted CIFAR-10 "python version" archive.
CIFAR10_DIR = "cifar-10-batches-py"

# Training set: data_batch_1 .. data_batch_5. Each batch dict provides
# 'data' (one flat row of 3*32*32 values per image, reshaped below) and
# 'labels' (a list of class ids).
x_batches = []
y_train = []
for i in range(1, 6):
    data_dic = unpickle("{}/data_batch_{}".format(CIFAR10_DIR, i))
    x_batches.append(data_dic['data'])
    y_train += data_dic['labels']
# Stack once instead of re-copying an ever-growing array with np.vstack on
# every iteration (and without the `None` / first-iteration special case).
x_train = np.vstack(x_batches)
del x_batches  # the per-batch arrays have been copied into x_train

test_data_dic = unpickle("{}/test_batch".format(CIFAR10_DIR))
x_test = test_data_dic['data']
y_test = np.array(test_data_dic['labels'])

# Reshape flat rows to (N, 3, 32, 32) -- Chainer conv layers take
# (batch, channels, height, width) -- and scale to [0, 1] as float32.
# Presumably the raw values are 0..255 pixel intensities.
x_train = x_train.reshape((len(x_train), 3, 32, 32)).astype(np.float32)
x_test = x_test.reshape((len(x_test), 3, 32, 32)).astype(np.float32)
x_train /= 255
x_test /= 255

# Chainer's softmax_cross_entropy / accuracy type-check for int32 labels.
y_train = np.array(y_train, dtype=np.int32)
y_test = y_test.astype(np.int32)

train = tuple_dataset.TupleDataset(x_train, y_train)
test = tuple_dataset.TupleDataset(x_test, y_test)
# L.Classifier adds softmax cross-entropy loss and accuracy on top of the
# CNN and reports them as main/loss and main/accuracy during training.
model = L.Classifier(Cifar10Model())
optimizer = chainer.optimizers.Adam()
optimizer.setup(model)

# Mini-batches of 100; the test iterator makes one ordered pass per
# evaluation (repeat=False, shuffle=False).
train_iter = chainer.iterators.SerialIterator(train, 100)
test_iter = chainer.iterators.SerialIterator(
    test, 100, repeat=False, shuffle=False)

# device=-1 keeps everything on the CPU.
updater = training.StandardUpdater(train_iter, optimizer, device=-1)
trainer = training.Trainer(updater, (40, 'epoch'), out="logs")


def _evaluate_without_dropout(x, t):
    """Evaluator pass that runs the CNN with dropout disabled.

    By default the Evaluator calls ``model(x, t)``; ``L.Classifier`` does
    not forward ``train=False`` to the predictor, so dropout stayed active
    while validation metrics were computed. The values are reported under
    ``model`` so they keep the validation/main/loss and
    validation/main/accuracy keys used by PrintReport below.
    """
    y = model.predictor(x, train=False)
    loss = model.lossfun(y, t)
    chainer.reporter.report(
        {'loss': loss, 'accuracy': F.accuracy(y, t)}, model)
    return loss


trainer.extend(extensions.Evaluator(
    test_iter, model, device=-1, eval_func=_evaluate_without_dropout))
trainer.extend(extensions.LogReport())
trainer.extend(extensions.PrintReport([
    'epoch',
    'main/loss', 'validation/main/loss',
    'main/accuracy', 'validation/main/accuracy',
]))
trainer.extend(extensions.ProgressBar())
trainer.run()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment