Skip to content

Instantly share code, notes, and snippets.

@rigibun
Last active September 27, 2016 16:11
Show Gist options
  • Save rigibun/067f349bae7ec3214d217300d3256313 to your computer and use it in GitHub Desktop.
Save rigibun/067f349bae7ec3214d217300d3256313 to your computer and use it in GitHub Desktop.
import chainer
from chainer import functions as F
from chainer import links as L
from chainer import training
from chainer.training import extensions
# Training hyper-parameters shared by the iterators and the trainer below.
BATCH_SIZE, EPOCH = (
    50,  # examples per minibatch (used by both the train and test iterators)
    10,  # number of epochs; used as the trainer's (EPOCH, 'epoch') stop trigger
)
class MLP(chainer.Chain):
    """Multi-layer perceptron with two ReLU+dropout hidden layers.

    The network computes ``l3(dropout(relu(l2(dropout(relu(l1(x)))))))`` and
    returns the raw output of the last linear layer (no softmax is applied
    here).

    The defaults reproduce the original hard-coded architecture
    (784 -> 1000 -> 1000 -> 10, dropout ratio 0.3). These sizes presumably
    correspond to flattened 28x28 MNIST images and 10 digit classes — confirm
    if the model is reused with other data.

    Args:
        n_in: Length of each input feature vector.
        n_units: Width of both hidden layers.
        n_out: Number of output units (e.g. classes).
        dropout_ratio: Dropout ratio applied after each hidden ReLU.
    """

    def __init__(self, n_in=784, n_units=1000, n_out=10, dropout_ratio=0.3):
        # Links are registered through the Chain constructor's keyword
        # arguments, exactly as in the original.
        # NOTE(review): Chainer >= 2 prefers registering links inside
        # ``with self.init_scope():``. The keyword form is kept for
        # compatibility with older Chainer releases; confirm the targeted
        # version before switching.
        super(MLP, self).__init__(
            l1=L.Linear(n_in, n_units),
            l2=L.Linear(n_units, n_units),
            l3=L.Linear(n_units, n_out),
        )
        # Plain float attribute (not a parameter); read in __call__.
        self.dropout_ratio = dropout_ratio

    def __call__(self, x):
        """Return the final linear layer's output for the minibatch ``x``.

        NOTE(review): whether dropout is active during evaluation depends on
        the Chainer version.
          * Chainer v1: ``F.dropout`` drops units unless ``train=False`` is
            passed, so it also fires during ``Evaluator`` runs.
          * Chainer >= 2: it follows ``chainer.config.train``.
        Confirm the installed version if validation metrics look noisy.
        """
        h = F.dropout(F.relu(self.l1(x)), ratio=self.dropout_ratio)
        h = F.dropout(F.relu(self.l2(h)), ratio=self.dropout_ratio)
        return self.l3(h)
# The Classifier wrapper around MLP is what produces the loss/accuracy values
# that appear as the 'main/...' columns in the report configured below.
model = L.Classifier(MLP())

# Adam optimizer attached to the wrapped model.
optimizer = chainer.optimizers.Adam()
optimizer.setup(model)

# MNIST train/test splits.
train, test = chainer.datasets.get_mnist()

# The training iterator keeps SerialIterator's default repeat/shuffle
# settings. The test iterator makes a single pass in dataset order.
train_iter = chainer.iterators.SerialIterator(train, batch_size=BATCH_SIZE)
test_iter = chainer.iterators.SerialIterator(
    test,
    batch_size=BATCH_SIZE,
    repeat=False,
    shuffle=False,
)

# The updater feeds training batches to the optimizer. The trainer stops
# after EPOCH epochs and writes its output under 'result'.
updater = training.StandardUpdater(train_iter, optimizer)
trainer = training.Trainer(updater, (EPOCH, 'epoch'), out='result')

# Extensions are registered in their original order:
#   1. evaluation on the test split,
#   2. log aggregation,
#   3. a printed report of the listed keys,
#   4. a progress bar.
for extension in (
    extensions.Evaluator(test_iter, model),
    extensions.LogReport(),
    extensions.PrintReport([
        'epoch',
        'main/loss',
        'main/accuracy',
        'validation/main/loss',
        'validation/main/accuracy',
    ]),
    extensions.ProgressBar(),
):
    trainer.extend(extension)

trainer.run()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment