'''
An implementation of HashedNets (Chen et al., ICML 2015,
"Compressing Neural Networks with the Hashing Trick") in Chainer.
http://joisino.hatenablog.com/entry/2017/07/27/210000

Copyright (c) 2017 joisino
Released under the MIT license
http://opensource.org/licenses/mit-license.php
'''
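
# HashedNets idea (from the paper above): instead of storing a dense m x n
# weight matrix V, store one parameter vector w of length K and define
# V[i, j] = w[h(i, j)], where h is a fixed random hash into {0, ..., K-1}.
# Virtual entries whose indices collide share a single parameter, so the
# layer costs O(K) memory instead of O(n * m).
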
import sys

import numpy as np

import chainer
from chainer import Chain, Function, Link, training
from chainer import datasets, iterators, optimizers
import chainer.functions as F
import chainer.links as L
from chainer.training import extensions


class HashedLinearFunction(Function):
    '''
    A fully connected layer whose virtual n -> m weight matrix and bias
    are both backed by a single shared parameter vector of length K.
    '''

    def __init__(self, n, m, K, table):
        super(HashedLinearFunction, self).__init__()
        self.n = n
        self.m = m
        self.K = K
        self.table = table
        # XOR of values in [0, K) stays in [0, K) only when K is a power of 2
        assert bin(K).count("1") == 1, "K must be a power of 2"

    def hash(self, x):
        # Byte-wise tabulation-style hash: XOR together table lookups of the
        # four bytes of each (32-bit) index, giving a bucket in [0, K).
        res = np.zeros(x.shape, dtype=np.int32)
        cur = x.copy()
        for i in range(4):
            res ^= self.table[cur & 255]
            cur >>= 8
        return res
    def forward(self, inputs):
        x = inputs[0]
        W = inputs[1]  # shared parameter vector of length K
        # Materialize the virtual (m, n) weight matrix: entry (i, j) is
        # W[hash(i * n + j)], so colliding entries share one parameter.
        Wha = self.hash(np.arange(0, self.n * self.m))
        vW = W[Wha.reshape(self.m, self.n)]
        # The m bias terms are hashed into the same parameter vector, using
        # index offsets past the virtual weight entries.
        bha = self.hash(np.arange(self.n * self.m, self.n * self.m + self.m))
        b = W[bha]
        y = x.dot(vW.T) + b
        return y,
    def backward(self, inputs, grad_outputs):
        x = inputs[0]
        W = inputs[1]
        gy = grad_outputs[0]
        Wha = self.hash(np.arange(0, self.n * self.m))
        vW = W[Wha.reshape(self.m, self.n)]
        bha = self.hash(np.arange(self.n * self.m, self.n * self.m + self.m))
        gx = gy.dot(vW)
        # Gradients w.r.t. the virtual weights and biases ...
        gvW = gy.T.dot(x)
        gb = gy.sum(0)
        # ... are scattered back onto the K shared parameters: each bucket
        # accumulates the gradients of every virtual entry hashed into it.
        gW = np.zeros(self.K, dtype=np.float32)
        gW += np.bincount(Wha, weights=gvW.reshape(-1), minlength=self.K)
        gW += np.bincount(bha, weights=gb, minlength=self.K)
        return gx, gW
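

# A minimal standalone sketch (illustrative only, not used in training):
# materialize a small virtual weight matrix with the same byte-wise XOR hash
# and observe that colliding virtual entries share one parameter. The helper
# name `demo_weight_sharing` is hypothetical.
def demo_weight_sharing(n=3, m=4, K=8, seed=0):
    rng = np.random.RandomState(seed)
    table = rng.randint(0, K, 256)
    idx = np.arange(n * m)
    buckets = np.zeros(idx.shape, dtype=np.int32)
    cur = idx.copy()
    for _ in range(4):
        buckets ^= table[cur & 255]
        cur >>= 8
    W = rng.randn(K).astype(np.float32)
    vW = W[buckets.reshape(m, n)]  # (m, n) virtual matrix, <= K distinct values
    print(buckets.reshape(m, n))   # bucket indices; repeats are collisions
    print(vW)                      # collided entries hold identical values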


def hashed_linear(x, W, n, m, K, table):
    func = HashedLinearFunction(n, m, K, table)
    return func(x, W)
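

# Optional sanity check (a sketch): numerically verify the hand-written
# backward pass with chainer.gradient_check on a tiny layer. The helper name
# and the tolerances are illustrative.
def check_hashed_linear_grad(n=5, m=4, K=8, batch=3, seed=0):
    from chainer import gradient_check
    rng = np.random.RandomState(seed)
    table = rng.randint(0, K, 256)
    x = rng.randn(batch, n).astype(np.float32)
    W = rng.randn(K).astype(np.float32)
    gy = rng.randn(batch, m).astype(np.float32)

    def f(x, W):
        # fresh Function instance per call, as gradient_check re-evaluates it
        return hashed_linear(x, W, n, m, K, table)

    gradient_check.check_backward(f, (x, W), gy, eps=1e-2, atol=1e-3, rtol=1e-3)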


class HashedLinear(Link):

    def __init__(self, n, m, K):
        super(HashedLinear, self).__init__()
        self.n = n
        self.m = m
        self.K = K
        # Random lookup table for the byte-wise hash; fixed per layer.
        self.table = np.random.randint(0, K, 256)
        with self.init_scope():
            # K shared parameters, scaled like a dense layer's weights
            # (standard deviation 1 / sqrt(fan-in)).
            self.W = chainer.Parameter(
                (np.random.randn(K) / np.sqrt(n)).astype(np.float32))

    def __call__(self, x):
        return hashed_linear(x, self.W, self.n, self.m, self.K, self.table)
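

# A quick usage sketch (illustrative helper): one HashedLinear emulates a
# dense n -> m layer with n * m + m virtual parameters while storing only
# K real ones.
def demo_compression():
    layer = HashedLinear(784, 100, 2 ** 10)
    x = np.random.randn(5, 784).astype(np.float32)
    y = layer(x)
    print(y.shape)          # (5, 100)
    print(layer.W.size)     # 1024 real parameters stored
    print(784 * 100 + 100)  # 78500 virtual weights + biases emulated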


class MLP(Chain):

    def __init__(self, n_in, n_mid, n_out, K):
        super(MLP, self).__init__()
        self.n_in = n_in
        with self.init_scope():
            self.hl1 = HashedLinear(n_in, n_mid, K)
            self.hl2 = HashedLinear(n_mid, n_mid, K)
            self.hl3 = HashedLinear(n_mid, n_out, K)

    def __call__(self, x):
        h0 = F.reshape(x, (-1, self.n_in))
        h1 = F.relu(self.hl1(h0))
        h2 = F.relu(self.hl2(h1))
        # Return raw logits: L.Classifier applies softmax_cross_entropy
        # itself, so an extra F.softmax here would squash the gradients.
        return self.hl3(h2)
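

# A worked parameter count (a sketch; the layer sizes match the defaults
# below): the MLP stores 3 * K real parameters, while its dense counterpart
# would store 89,610. Compression therefore kicks in only for K well below
# ~30,000; the default K = 2**16 actually over-parameterizes, and the
# command-line argument is what shrinks it.
def demo_param_budget(K=2 ** 16):
    dense = (784 * 100 + 100) + (100 * 100 + 100) + (100 * 10 + 10)
    real = 3 * K
    print(dense, real)  # 89610 196608 for K = 2**16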


in_dim = 28 * 28
mid_dim = 100
out_dim = 10

# Number of shared parameters per layer (must be a power of two); it can be
# overridden from the command line as an exponent of 2.
K = 2 ** 16
if len(sys.argv) == 2:
    K = 2 ** int(sys.argv[1])

n_epoch = 20

train, test = datasets.get_mnist()
train_iter = iterators.SerialIterator(train, batch_size=100, shuffle=True)
test_iter = iterators.SerialIterator(test, batch_size=100, repeat=False, shuffle=False)

mlp = L.Classifier(MLP(in_dim, mid_dim, out_dim, K))
opt = optimizers.Adam()
opt.setup(mlp)

updater = training.StandardUpdater(train_iter, opt)
trainer = training.Trainer(updater, (n_epoch, 'epoch'), out='result')
trainer.extend(extensions.Evaluator(test_iter, mlp))
trainer.extend(extensions.LogReport())
trainer.extend(extensions.PrintReport(
    ['epoch', 'main/accuracy', 'validation/main/accuracy']))
trainer.extend(extensions.ProgressBar())
trainer.run()
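
# How to run (the file name is illustrative):
#   python hashednets.py       # default K = 2**16
#   python hashednets.py 12    # use K = 2**12 instead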