Batch Normalization
'''
Implementation of batch normalization with chainer
http://joisino.hatenablog.com/entry/2017/07/09/210000
Copyright (c) 2017 joisino
Released under the MIT license
http://opensource.org/licenses/mit-license.php
'''
import numpy as np
import chainer
from chainer import Function, report, training, utils, Variable
from chainer import datasets, iterators, optimizers, serializers
from chainer import Link, Chain
import chainer.functions as F
import chainer.links as L
from chainer.training import extensions
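
# For reference, batch normalization (Ioffe & Szegedy, 2015) computes, per
# feature, over a mini-batch x of shape (batch, features):
#
#   mu      = mean(x)                        # batch mean
#   sigma^2 = mean((x - mu)^2)               # batch variance
#   x_hat   = (x - mu) / sqrt(sigma^2 + eps) # normalize
#   y       = gamma * x_hat + beta           # learned scale and shift
#
# The Link below follows these equations directly.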
class BatchNormalization(Link):
    def __init__(self, sz):
        super(BatchNormalization, self).__init__()
        self.eps = 0.00001
        with self.init_scope():
            # Learned per-feature shift (beta) and scale (gamma).
            self.beta = chainer.Parameter(np.zeros(sz, dtype=np.float32))
            self.gamma = chainer.Parameter(np.ones(sz, dtype=np.float32))

    def __call__(self, x):
        # Note: this simplified version always normalizes with the current
        # batch statistics; it keeps no running averages for test time.
        n = x.data.shape[0]  # mini-batch size
        mu = F.average(x, axis=0)
        sigma = F.average((x - F.tile(mu, (n, 1)))**2, axis=0)
        x_hat = (x - F.tile(mu, (n, 1))) / F.sqrt(F.tile(sigma + self.eps, (n, 1)))
        y = F.tile(self.gamma, (n, 1)) * x_hat + F.tile(self.beta, (n, 1))
        return y
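
# A minimal sanity check, not part of the original gist: with the initial
# gamma=1 and beta=0, every output feature should have mean ~0 and variance
# ~1 over the batch. The batch size (5) and feature size (3) are arbitrary.
def _check_bn():
    bn = BatchNormalization(3)
    x = Variable(np.random.randn(5, 3).astype(np.float32) * 10.0 + 5.0)
    y = bn(x)
    print('mean:', F.average(y, axis=0).data)      # ~ [0. 0. 0.]
    print('var :', F.average(y * y, axis=0).data)  # ~ [1. 1. 1.]

_check_bn()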
class MLP_BN(Chain):
    def __init__(self, n_in, n_mid, n_out):
        super(MLP_BN, self).__init__(
            l1 = L.Linear(n_in, n_mid),
            l2 = L.Linear(n_mid, n_mid),
            l3 = L.Linear(n_mid, n_out),
            bn1 = BatchNormalization(n_mid),
            bn2 = BatchNormalization(n_mid),
        )

    def __call__(self, x):
        # Plain version without batch normalization, for comparison:
        # h1 = F.relu(self.l1(x))
        # h2 = F.relu(self.l2(h1))
        h1 = self.bn1(F.relu(self.l1(x)))
        h2 = self.bn2(F.relu(self.l2(h1)))
        # Return raw logits: L.Classifier applies softmax_cross_entropy
        # itself, so calling F.softmax here would normalize twice.
        return self.l3(h2)
# Network and training hyperparameters.
in_dim = 28*28
mid_dim = 100
out_dim = 10
n_epoch = 20

# MNIST: 60,000 training and 10,000 test images of 28x28 digits.
train, test = datasets.get_mnist()
train_iter = iterators.SerialIterator(train, batch_size=100, shuffle=True)
test_iter = iterators.SerialIterator(test, batch_size=100, repeat=False, shuffle=False)

# Wrap the model with a classifier (softmax cross-entropy loss + accuracy).
MLP = L.Classifier(MLP_BN(in_dim, mid_dim, out_dim))
opt = optimizers.Adam()
opt.setup(MLP)

# Standard trainer loop with evaluation on the test set each epoch.
updater = training.StandardUpdater(train_iter, opt)
trainer = training.Trainer(updater, (n_epoch, 'epoch'), out='result')
trainer.extend(extensions.Evaluator(test_iter, MLP))
trainer.extend(extensions.LogReport())
trainer.extend(extensions.PrintReport(['epoch', 'main/accuracy', 'validation/main/accuracy']))
trainer.extend(extensions.ProgressBar())
trainer.run()
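
# Optional, not in the original gist: persist the trained weights with the
# serializers module imported above; the filename 'mlp_bn.npz' is arbitrary.
serializers.save_npz('mlp_bn.npz', MLP)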