'''
An implementation of HashedNets in Chainer
http://joisino.hatenablog.com/entry/2017/07/27/210000
Copyright (c) 2017 joisino
Released under the MIT license
http://opensource.org/licenses/mit-license.php
'''
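# HashedNets (Chen et al. 2015, "Compressing Neural Networks with the Hashing
# Trick") keep only K real weights per layer: every entry of the virtual
# weight matrix is mapped to one of the K buckets by a cheap hash function,
# so entries that collide share (and jointly update) the same value.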
import numpy as np
import chainer
from chainer import Function, report, training, utils, Variable
from chainer import datasets, iterators, optimizers, serializers
from chainer import initializers
from chainer import Link, Chain
import chainer.functions as F
import chainer.links as L
from chainer.training import extensions
import sys

class HashedLinearFunction(Function):
    def __init__(self, n, m, K, table):
        super(HashedLinearFunction, self).__init__()
        self.n = n          # number of input units
        self.m = m          # number of output units
        self.K = K          # number of real (shared) weight buckets
        self.table = table  # random table for byte-wise tabulation hashing
        assert bin(K).count("1") == 1, "K must be a power of 2"

    def hash(self, x):
        # Hash each integer index in x into [0, K) by XORing table entries for
        # each of its four bytes; the result stays below K because K is a power of 2.
        res = np.zeros(x.shape, dtype=np.int32)
        cur = x.copy()
        for i in range(4):
            res ^= self.table[cur & 255]
            cur >>= 8
        return res

    def forward(self, inputs):
        x = inputs[0]
        W = inputs[1]
        # Gather the virtual (m x n) weight matrix and the bias from the K
        # shared parameters, indexed by the hash of each virtual position.
        Wha = self.hash(np.arange(0, self.n * self.m))
        vW = W[Wha.reshape(self.m, self.n)]
        bha = self.hash(np.arange(self.n * self.m, self.n * self.m + self.m))
        b = W[bha]
        y = x.dot(vW.T) + b
        return y,

    def backward(self, inputs, grad_outputs):
        x = inputs[0]
        W = inputs[1]
        gy = grad_outputs[0]
        Wha = self.hash(np.arange(0, self.n * self.m))
        vW = W[Wha.reshape(self.m, self.n)]
        bha = self.hash(np.arange(self.n * self.m, self.n * self.m + self.m))
        gx = gy.dot(vW)
        gvW = gy.T.dot(x)
        gb = gy.sum(0)
        # Scatter-add the virtual weight/bias gradients back into the K shared
        # buckets: colliding entries accumulate into the same slot.
        gW = np.zeros(self.K, dtype=np.float32)
        gW += np.bincount(Wha, weights=gvW.reshape(-1), minlength=self.K)
        gW += np.bincount(bha, weights=gb, minlength=self.K)
        return gx, gW


def hashed_linear(x, W, n, m, K, table):
    func = HashedLinearFunction(n, m, K, table)
    return func(x, W)


class HashedLinear(Link):
    def __init__(self, n, m, K):
        super(HashedLinear, self).__init__()
        self.n = n
        self.m = m
        self.K = K
        # One hash table per layer; each byte value maps to a bucket in [0, K).
        self.table = np.random.randint(0, K, 256)
        with self.init_scope():
            # Only K real parameters back the virtual (m x n) weights and bias.
            self.W = chainer.Parameter(np.random.randn(K).astype(np.float32) / np.sqrt(n))

    def __call__(self, x):
        return hashed_linear(x, self.W, self.n, self.m, self.K, self.table)


class MLP(Chain):
    def __init__(self, n_in, n_mid, n_out, K):
        self.n_in = n_in
        super(MLP, self).__init__(
            hl1=HashedLinear(n_in, n_mid, K),
            hl2=HashedLinear(n_mid, n_mid, K),
            hl3=HashedLinear(n_mid, n_out, K),
        )

    def __call__(self, x):
        h0 = F.reshape(x, (-1, self.n_in))
        h1 = F.relu(self.hl1(h0))
        h2 = F.relu(self.hl2(h1))
        y = self.hl3(h2)
        return F.softmax(y)

in_dim = 28*28
mid_dim = 100
out_dim = 10
K = 2**16
if len(sys.argv) == 2:
    K = 2**int(sys.argv[1])
n_epoch = 20
train, test = datasets.get_mnist()
train_iter = iterators.SerialIterator(train, batch_size=100, shuffle=True)
test_iter = iterators.SerialIterator(test, batch_size=100, repeat=False, shuffle=False)
mlp = L.Classifier(MLP(in_dim, mid_dim, out_dim, K))
opt = optimizers.Adam()
opt.setup(mlp)
updater = training.StandardUpdater(train_iter, opt)
trainer = training.Trainer(updater, (n_epoch, 'epoch'), out='result')
trainer.extend(extensions.Evaluator(test_iter, mlp))
trainer.extend(extensions.LogReport())
trainer.extend(extensions.PrintReport(['epoch', 'main/accuracy', 'validation/main/accuracy']))
trainer.extend(extensions.ProgressBar())
trainer.run()
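
Running the script with a single integer argument p trains with K = 2**p shared weights per layer (K defaults to 2**16). As a quick sanity check of the layer itself, the short sketch below, assuming the classes defined above are in scope (the batch size 32 and K = 2**10 are just illustrative values), shows that a HashedLinear layer exposes only K real parameters no matter how large its virtual weight matrix is:

layer = HashedLinear(784, 100, 2**10)            # virtual 100 x 784 weight matrix + bias
x = np.random.randn(32, 784).astype(np.float32)  # dummy mini-batch
y = layer(x)
print(y.shape)        # (32, 100)
print(layer.W.shape)  # (1024,) -- only K = 1024 real weights are stored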