Skip to content

Instantly share code, notes, and snippets.

@trtd56
Last active August 9, 2016 18:46
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save trtd56/2ce7808a271cb5cbc35402a8a90bab49 to your computer and use it in GitHub Desktop.
Save trtd56/2ce7808a271cb5cbc35402a8a90bab49 to your computer and use it in GitHub Desktop.
ChainerでAutoencoder(+ trainerの使い方の備忘録) ref: http://qiita.com/trtd56/items/acf42277c29b57c05651
class Autoencoder(chainer.Chain):
    """Single-hidden-layer fully connected autoencoder.

    ``encoder`` (Linear, then ReLU) maps an ``n_in``-dimensional input to
    ``n_hidden`` units. ``decoder`` (Linear, then ReLU) maps the hidden code
    back to ``n_in`` units.

    Args:
        n_in: Size of each flattened input vector. Defaults to 784 (= 28*28,
            a flattened MNIST image), matching the original hard-coded value.
        n_hidden: Size of the bottleneck (code) layer. Defaults to 64,
            matching the original hard-coded value.
    """

    def __init__(self, n_in=784, n_hidden=64):
        # Child links are registered by passing them as keyword arguments to
        # Chain.__init__ (the pre-v2 Chainer style this script is written in).
        super(Autoencoder, self).__init__(
            encoder=L.Linear(n_in, n_hidden),
            decoder=L.Linear(n_hidden, n_in))

    def __call__(self, x, hidden=False):
        """Encode ``x`` and, unless ``hidden`` is true, decode it again.

        Args:
            x: Input batch; presumably shape (batch, n_in), float32
                (TODO confirm against callers).
            hidden: If true, return the ReLU-activated hidden code instead
                of the reconstruction.

        Returns:
            The hidden activation when ``hidden`` is true, otherwise the
            ReLU-activated reconstruction of ``x``.
        """
        h = F.relu(self.encoder(x))
        if hidden:
            return h
        return F.relu(self.decoder(h))
# Load the MNIST data (train and test splits).
train, test = chainer.datasets.get_mnist()
# Training data.
# Only the first 1000 training examples are used, to keep training short.
train = train[0:1000]
# Drop the labels: keep only the image (element 0) of each (image, label) pair.
train = [i[0] for i in train]
# Pair every image with itself so it serves as both the input and the target
# of the autoencoder.
train = tuple_dataset.TupleDataset(train, train)
# Iterate over the training set in mini-batches of 100 examples.
train_iter = chainer.iterators.SerialIterator(train, 100)
# Test data.
# Keep the first 25 test examples; plot_mnist_data lays them out on a 5x5 grid.
test = test[0:25]
# L.Classifier wraps the autoencoder so the trainer gets a scalar loss:
# lossfun(predictor(x), t), here the MSE between the reconstruction and the
# input image itself.
model = L.Classifier(Autoencoder(), lossfun=F.mean_squared_error)
# Accuracy is meaningless for a reconstruction target, so disable it.
model.compute_accuracy = False
optimizer = chainer.optimizers.Adam()
optimizer.setup(model)
# device=-1 selects the CPU.
updater = training.StandardUpdater(train_iter, optimizer, device=-1)
# NOTE(review): N_EPOCH is not defined in this snippet; it must be set before
# this line runs. Trainer output (e.g. LogReport's log) is written to ./result.
trainer = training.Trainer(updater, (N_EPOCH, 'epoch'), out="result")
# Record per-epoch statistics, print epoch and loss, and show a progress bar.
trainer.extend(extensions.LogReport())
trainer.extend(extensions.PrintReport( ['epoch', 'main/loss']))
trainer.extend(extensions.ProgressBar())
trainer.run()
def plot_mnist_data(samples, rows=5, cols=5):
    """Display (image, label) pairs as a grid of 28x28 grayscale images.

    Each image is reshaped to 28x28 and drawn with the reversed gray colormap.
    Its label is drawn in red as the subplot title. The figure is shown once,
    after every sample has been placed.

    Args:
        samples: Iterable of ``(data, label)`` pairs. ``data`` must support
            ``reshape(28, 28)``; ``label`` must be convertible with ``int()``.
        rows: Number of grid rows (default 5, as before).
        cols: Number of grid columns (default 5, as before).

    Raises:
        ValueError: If there are more samples than ``rows * cols`` grid cells.
            This is checked up front rather than failing inside
            ``plt.subplot`` after part of the figure has been drawn.
    """
    samples = list(samples)
    capacity = rows * cols
    if len(samples) > capacity:
        raise ValueError(
            "plot_mnist_data: got %d samples but the %dx%d grid holds only %d"
            % (len(samples), rows, cols, capacity))
    for index, (data, label) in enumerate(samples):
        # subplot indices are 1-based.
        plt.subplot(rows, cols, index + 1)
        plt.axis('off')
        plt.imshow(data.reshape(28, 28), cmap=cm.gray_r,
                   interpolation='nearest')
        # plt.title expects text, so convert the numeric label explicitly.
        plt.title(str(int(label)), color='red')
    plt.show()
# Run every test image through the trained autoencoder, keeping the original
# label alongside each reconstruction so the plot can caption it.
pred_list = [
    # Wrap the image in a batch of one, cast to float32 for the forward pass,
    # and unwrap the resulting chainer Variable with .data.
    (model.predictor(np.asarray([data], dtype=np.float32)).data, label)
    for data, label in test
]
plot_mnist_data(pred_list)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment