# Source: GitHub gist "adam.py" by @joisino (last active Jul 23, 2017).
'''
An implementation of the Adam optimizer in Chainer.
http://joisino.hatenablog.com/entry/2017/07/20/210000
Copyright (c) 2017 joisino
Released under the MIT license
http://opensource.org/licenses/mit-license.php
'''
import numpy as np
import chainer
from chainer import Function, report, training, utils, Variable
from chainer import datasets, iterators, optimizers, serializers
from chainer import initializers
from chainer import Link, Chain
import chainer.functions as F
import chainer.links as L
from chainer.training import extensions
from chainer import optimizer
class AdamRule(optimizer.UpdateRule):
    """Per-parameter update rule implementing Adam (Kingma & Ba, 2014).

    Reads ``alpha``, ``beta1``, ``beta2`` and ``eps`` from the hyperparameter
    object it is constructed with, and keeps its own step counter ``t`` and
    the first/second moment estimates ``m`` and ``v`` in ``self.state``.
    """

    def __init__(self, hyperparam):
        super(AdamRule, self).__init__(hyperparam)

    def init_state(self, param):
        # Scalar zeros broadcast against the first gradient, so ``m`` and
        # ``v`` pick up the gradient's shape/dtype/array type on step one.
        self.state['t'] = 0
        self.state['m'] = 0
        self.state['v'] = 0

    def update_core(self, param):
        grad = param.grad
        if grad is None:
            # No gradient for this parameter; leave it and its state as-is.
            return
        hp = self.hyperparam
        state = self.state

        # Exponential moving averages of the gradient and squared gradient.
        state['m'] = hp.beta1 * state['m'] + (1 - hp.beta1) * grad
        state['v'] = hp.beta2 * state['v'] + (1 - hp.beta2) * grad * grad
        state['t'] += 1
        t = state['t']

        # Bias correction: both moments start at zero, so early estimates are
        # scaled up by 1 / (1 - beta**t). Plain Python float powers are used
        # since these are scalars.
        m_hat = state['m'] / (1 - hp.beta1 ** t)
        v_hat = state['v'] / (1 - hp.beta2 ** t)

        # ``** 0.5`` (rather than ``np.sqrt``) dispatches to the array's own
        # implementation, so the update also works when the parameter lives
        # on the GPU (CuPy array) and not only for NumPy arrays.
        param.data -= hp.alpha * m_hat / (v_hat ** 0.5 + hp.eps)
class Adam(optimizer.GradientMethod):
    """Gradient-method optimizer that applies :class:`AdamRule` to each parameter.

    Args:
        alpha: Step size.
        beta1: Decay rate of the first-moment (mean) estimate.
        beta2: Decay rate of the second-moment (uncentered variance) estimate.
        eps: Constant added to the denominator for numerical stability.
    """

    # Expose the hyperparameters as optimizer attributes (``opt.alpha`` etc.)
    # that read and write through ``self.hyperparam``. Without these, code
    # that does ``getattr``/``setattr`` on the optimizer (for example
    # ``extensions.ExponentialShift('alpha', ...)``) would either fail or
    # change a plain attribute that the update rules never read.
    alpha = optimizer.HyperparameterProxy('alpha')
    beta1 = optimizer.HyperparameterProxy('beta1')
    beta2 = optimizer.HyperparameterProxy('beta2')
    eps = optimizer.HyperparameterProxy('eps')

    def __init__(self, alpha=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
        super(Adam, self).__init__()
        self.hyperparam.alpha = alpha
        self.hyperparam.beta1 = beta1
        self.hyperparam.beta2 = beta2
        self.hyperparam.eps = eps

    def create_update_rule(self):
        # Each rule shares this optimizer's hyperparameter object as parent,
        # so changes made here are seen by every parameter's rule.
        return AdamRule(self.hyperparam)
class MLP(Chain):
    """Three-layer perceptron: two ReLU hidden layers and a linear output.

    Args:
        n_in: Input feature size.
        n_mid: Width of both hidden layers.
        n_out: Number of output units.
    """

    def __init__(self, n_in, n_mid, n_out):
        super(MLP, self).__init__()
        # Register child links inside init_scope(), the Chainer v2+ style;
        # passing them as keyword arguments to Chain.__init__ is the older API.
        with self.init_scope():
            self.l1 = L.Linear(n_in, n_mid)
            self.l2 = L.Linear(n_mid, n_mid)
            self.l3 = L.Linear(n_mid, n_out)

    def __call__(self, x):
        """Return unnormalized class scores (logits) for the batch ``x``.

        Softmax is deliberately not applied here. When wrapped in
        ``L.Classifier`` (as this script does), the default loss
        ``F.softmax_cross_entropy`` applies softmax itself. Applying it twice
        squashes the scores and weakens the gradients. Apply ``F.softmax``
        to the result if probabilities are needed.
        """
        h1 = F.relu(self.l1(x))
        h2 = F.relu(self.l2(h1))
        return self.l3(h2)
# Layer sizes: flattened 28x28 images in, 100 hidden units, 10 classes out.
in_dim, mid_dim, out_dim = 28 * 28, 100, 10
n_epoch = 10

# Minibatches of 100. The test iterator makes one ordered pass per
# evaluation (no repeat, no shuffle).
train, test = datasets.get_mnist()
train_iter = iterators.SerialIterator(train, 100, shuffle=True)
test_iter = iterators.SerialIterator(test, 100, repeat=False, shuffle=False)

# The hand-written Adam from this file; swap in optimizers.Adam() to compare
# against Chainer's built-in implementation.
mlp = L.Classifier(MLP(in_dim, mid_dim, out_dim))
opt = Adam()
opt.setup(mlp)

updater = training.StandardUpdater(train_iter, opt)
trainer = training.Trainer(updater, stop_trigger=(n_epoch, 'epoch'), out='result')

# Evaluate on the test split, log, print a per-epoch summary, show progress.
trainer.extend(extensions.Evaluator(test_iter, mlp))
trainer.extend(extensions.LogReport())
trainer.extend(extensions.PrintReport(
    ['epoch', 'main/accuracy', 'validation/main/accuracy']))
trainer.extend(extensions.ProgressBar())

trainer.run()