Skip to content

Instantly share code, notes, and snippets.

@joisino
Last active July 23, 2017 00:28
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save joisino/31a47619349ae8f3326cc8b9641cbcc1 to your computer and use it in GitHub Desktop.
Save joisino/31a47619349ae8f3326cc8b9641cbcc1 to your computer and use it in GitHub Desktop.
'''
The implementation of Adam with chainer
http://joisino.hatenablog.com/entry/2017/07/20/210000
Copyright (c) 2017 joisino
Released under the MIT license
http://opensource.org/licenses/mit-license.php
'''
import numpy as np
import chainer
from chainer import Function, report, training, utils, Variable
from chainer import datasets, iterators, optimizers, serializers
from chainer import initializers
from chainer import Link, Chain
import chainer.functions as F
import chainer.links as L
from chainer.training import extensions
from chainer import optimizer
class AdamRule(optimizer.UpdateRule):
    """Per-parameter update rule implementing the Adam algorithm.

    Each parameter keeps its own first-moment estimate ``m``, second-moment
    estimate ``v`` and step counter ``t`` in ``self.state``.  The
    hyperparameters ``alpha``, ``beta1``, ``beta2`` and ``eps`` are read from
    the ``hyperparam`` object passed to the constructor.
    """

    def __init__(self, hyperparam):
        super(AdamRule, self).__init__(hyperparam)

    def init_state(self, param):
        """Reset the moment estimates and the step counter.

        ``m`` and ``v`` start as the scalar 0; they become arrays of the
        gradient's shape on the first update through broadcasting.
        """
        self.state['t'] = 0
        self.state['m'] = 0
        self.state['v'] = 0

    def update_core(self, param):
        """Apply one Adam step to ``param.data`` in place.

        A parameter whose gradient is ``None`` is skipped entirely, and its
        step counter ``t`` is not advanced.
        """
        grad = param.grad
        if grad is None:
            return
        hp = self.hyperparam
        state = self.state

        # Exponential moving averages of the gradient and squared gradient.
        state['m'] = hp.beta1 * state['m'] + (1 - hp.beta1) * grad
        state['v'] = hp.beta2 * state['v'] + (1 - hp.beta2) * grad * grad
        state['t'] += 1

        # Bias correction: m and v start at 0, so early estimates are biased
        # towards zero.  Plain ``**`` is used instead of np.power / np.sqrt so
        # the update does not depend on the array module (it works for both
        # NumPy and CuPy arrays), and the correction factor stays a Python
        # float rather than a float64 NumPy scalar.
        t = state['t']
        m_hat = state['m'] / (1 - hp.beta1 ** t)
        v_hat = state['v'] / (1 - hp.beta2 ** t)
        param.data -= hp.alpha * m_hat / (v_hat ** 0.5 + hp.eps)
class Adam(optimizer.GradientMethod):
    """Adam optimizer whose per-parameter updates are done by ``AdamRule``.

    Args:
        alpha: Step size.
        beta1: Decay rate of the first-moment (gradient mean) estimate.
        beta2: Decay rate of the second-moment (squared gradient) estimate.
        eps: Small constant added to the denominator for numerical stability.
    """

    # Alias the hyperparameters as optimizer attributes (``opt.alpha`` etc.).
    # Code that reads/writes hyperparameters by name on the optimizer -- for
    # example the trainer extension ``ExponentialShift('alpha', ...)`` --
    # otherwise fails, because they would only live on ``self.hyperparam``.
    alpha = optimizer.HyperparameterProxy('alpha')
    beta1 = optimizer.HyperparameterProxy('beta1')
    beta2 = optimizer.HyperparameterProxy('beta2')
    eps = optimizer.HyperparameterProxy('eps')

    def __init__(self, alpha=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
        super(Adam, self).__init__()
        self.hyperparam.alpha = alpha
        self.hyperparam.beta1 = beta1
        self.hyperparam.beta2 = beta2
        self.hyperparam.eps = eps

    def create_update_rule(self):
        """Return a new ``AdamRule`` built from this optimizer's hyperparameters."""
        return AdamRule(self.hyperparam)
class MLP(Chain):
    """Three-layer perceptron with ReLU activations on the hidden layers.

    Args:
        n_in: Number of input features.
        n_mid: Width of both hidden layers.
        n_out: Number of output units.
    """

    def __init__(self, n_in, n_mid, n_out):
        super(MLP, self).__init__()
        # Links assigned inside init_scope are registered as children of the
        # chain, so their parameters are seen by the optimizer.
        with self.init_scope():
            self.l1 = L.Linear(n_in, n_mid)
            self.l2 = L.Linear(n_mid, n_mid)
            self.l3 = L.Linear(n_mid, n_out)

    def __call__(self, x):
        """Return unnormalized output scores (logits) for the batch ``x``.

        No softmax is applied here: the model is wrapped in ``L.Classifier``,
        whose default loss ``F.softmax_cross_entropy`` applies softmax itself.
        Applying it in the model as well would softmax twice and flatten the
        gradients.  Call ``F.softmax`` on the output to get probabilities.
        """
        h1 = F.relu(self.l1(x))
        h2 = F.relu(self.l2(h1))
        return self.l3(h2)
# Network shape (flattened 28x28 MNIST images -> 10 classes) and run length.
in_dim, mid_dim, out_dim = 28 * 28, 100, 10
n_epoch = 10

# Training data is a shuffled, endlessly repeating stream; the test data is
# read once per evaluation, in order.
train, test = datasets.get_mnist()
train_iter = iterators.SerialIterator(train, batch_size=100, shuffle=True)
test_iter = iterators.SerialIterator(
    test, batch_size=100, repeat=False, shuffle=False)

mlp = L.Classifier(MLP(in_dim, mid_dim, out_dim))

# The hand-written Adam defined in this file; replace with
# optimizers.Adam() to compare against chainer's built-in implementation.
opt = Adam()
opt.setup(mlp)

updater = training.StandardUpdater(train_iter, opt)
trainer = training.Trainer(updater, (n_epoch, 'epoch'), out='result')

# Registration order is kept: evaluation, logging, printing, progress bar.
for ext in (
    extensions.Evaluator(test_iter, mlp),
    extensions.LogReport(),
    extensions.PrintReport(
        ['epoch', 'main/accuracy', 'validation/main/accuracy']),
    extensions.ProgressBar(),
):
    trainer.extend(ext)

trainer.run()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment