Instantly share code, notes, and snippets.

# joisino/HashedNets.py Created Jul 24, 2017

 # HashedNets implemented with Chainer: layer definitions plus an MNIST training script.
'''
The implementation of HashedNets with chainer
http://joisino.hatenablog.com/entry/2017/07/27/210000

Copyright (c) 2017 joisino
Released under the MIT license
http://opensource.org/licenses/mit-license.php
'''
import numpy as np
import chainer
from chainer import Function, report, training, utils, Variable
from chainer import datasets, iterators, optimizers, serializers
from chainer import initializers
from chainer import Link, Chain
import chainer.functions as F
import chainer.links as L
from chainer.training import extensions
import sys

# Number of low-order bytes of an index that take part in hashing; every
# hashed index must therefore be smaller than 2**32.
_N_HASH_BYTES = 4


def _check_bucket_count(K):
    """Raise ValueError unless K is a positive power of two.

    The hash XORs table entries drawn from [0, K); the XOR of such values
    stays inside [0, K) only when K is a power of two.
    """
    if K <= 0 or K & (K - 1):
        raise ValueError("K must be powers of 2")


class HashedLinearFunction(Function):
    """Fully connected layer whose weights are gathered from a shared vector.

    The virtual (m, n) weight matrix and length-m bias are not stored; entry
    (i, j) of the weight is W[h(i * n + j)] and bias i is W[h(n * m + i)],
    where W has length K and h is ``self.hash``.

    Args:
        n: input width.
        m: output width.
        K: number of hash buckets (length of W); must be a power of two.
        table: hash table of shape (4, 256), one row per byte position.
            A legacy 1-D table is accepted and shared by all byte positions
            (weaker hashing, kept for backward compatibility).

    Raises:
        ValueError: if K is not a power of two, the indices do not fit in
            32 bits, or the table has an unsupported shape.
    """

    def __init__(self, n, m, K, table):
        super(HashedLinearFunction, self).__init__()
        _check_bucket_count(K)
        if n * m + m > 1 << (8 * _N_HASH_BYTES):
            # Only the low 32 bits are hashed; larger indices would alias.
            raise ValueError("n * m + m must fit in 32 bits")
        table = np.asarray(table)
        if table.ndim == 1:
            # Legacy layout: one table reused for every byte position.
            table = np.broadcast_to(table, (_N_HASH_BYTES, table.shape[0]))
        if (table.ndim != 2 or table.shape[0] != _N_HASH_BYTES
                or table.shape[1] < 256):
            raise ValueError("table must have shape (256,) or (4, 256)")
        self.n = n
        self.m = m
        self.K = K
        self.table = table
        # (weight bucket indices, bias bucket indices); filled on first use
        # so forward and backward share one hashing pass.
        self._hashed = None

    def hash(self, x):
        """Map non-negative integer indices ``x`` to buckets in [0, K).

        Byte b of each index selects an entry of table row b and the selected
        entries are XORed together (tabulation hashing). Assumes every table
        entry lies in [0, K).
        """
        res = np.zeros(x.shape, dtype=np.int64)
        for b, row in enumerate(self.table):
            res ^= row[(x >> (8 * b)) & 255]
        return res

    def _indices(self):
        """Return (bucket per flattened (m, n) weight entry, bucket per bias)."""
        if self._hashed is None:
            nm = self.n * self.m
            self._hashed = (self.hash(np.arange(0, nm)),
                            self.hash(np.arange(nm, nm + self.m)))
        return self._hashed

    def forward(self, inputs):
        x, W = inputs
        w_idx, b_idx = self._indices()
        vW = W[w_idx.reshape(self.m, self.n)]  # virtual (m, n) weight matrix
        b = W[b_idx]
        y = x.dot(vW.T) + b
        return y,

    def backward(self, inputs, grad_outputs):
        x, W = inputs
        gy, = grad_outputs
        w_idx, b_idx = self._indices()
        vW = W[w_idx.reshape(self.m, self.n)]
        gx = gy.dot(vW)
        gvW = gy.T.dot(x)  # gradient w.r.t. the virtual (m, n) matrix
        gb = gy.sum(axis=0)
        # Every virtual entry hashed to bucket k contributes to dL/dW[k].
        gW = (np.bincount(w_idx, weights=gvW.reshape(-1), minlength=self.K)
              + np.bincount(b_idx, weights=gb, minlength=self.K))
        return gx, gW.astype(W.dtype, copy=False)


def hashed_linear(x, W, n, m, K, table):
    """Apply a hashed fully connected layer (see HashedLinearFunction)."""
    func = HashedLinearFunction(n, m, K, table)
    return func(x, W)


class HashedLinear(Link):
    """Link owning the shared K-vector W and the hash table of one layer.

    Args:
        n: input width.
        m: output width.
        K: number of hash buckets; must be a power of two.
    """

    def __init__(self, n, m, K):
        super(HashedLinear, self).__init__()
        _check_bucket_count(K)
        self.n = n
        self.m = m
        self.K = K
        # One independent row per byte position. Registered as persistent so
        # serializers save it with W -- W is meaningless without its table.
        self.add_persistent(
            'table', np.random.randint(0, K, size=(_N_HASH_BYTES, 256)))
        with self.init_scope():
            # Scale first, then cast, so W is float32 under any NumPy
            # scalar-promotion rules.
            self.W = chainer.Parameter(
                (np.random.randn(K) / np.sqrt(n)).astype(np.float32))

    def __call__(self, x):
        return hashed_linear(x, self.W, self.n, self.m, self.K, self.table)


class MLP(Chain):
    """Three hashed layers with ReLU activations; returns unnormalized logits.

    Logits (not probabilities) are returned because L.Classifier's default
    loss, softmax_cross_entropy, applies the softmax itself.
    """

    def __init__(self, n_in, n_mid, n_out, K):
        super(MLP, self).__init__()
        self.n_in = n_in
        with self.init_scope():
            self.hl1 = HashedLinear(n_in, n_mid, K)
            self.hl2 = HashedLinear(n_mid, n_mid, K)
            self.hl3 = HashedLinear(n_mid, n_out, K)

    def __call__(self, x):
        h0 = F.reshape(x, (-1, self.n_in))
        h1 = F.relu(self.hl1(h0))
        h2 = F.relu(self.hl2(h1))
        return self.hl3(h2)


def main(argv=None):
    """Train the hashed MLP on MNIST.

    An optional single command-line argument gives log2 of K (default 16).
    """
    if argv is None:
        argv = sys.argv
    in_dim = 28 * 28
    mid_dim = 100
    out_dim = 10
    K = 2 ** 16
    if len(argv) == 2:
        K = 2 ** int(argv[1])
    n_epoch = 20

    train, test = datasets.get_mnist()
    train_iter = iterators.SerialIterator(train, batch_size=100, shuffle=True)
    test_iter = iterators.SerialIterator(test, batch_size=100,
                                         repeat=False, shuffle=False)

    mlp = L.Classifier(MLP(in_dim, mid_dim, out_dim, K))
    opt = optimizers.Adam()
    opt.setup(mlp)

    updater = training.StandardUpdater(train_iter, opt)
    trainer = training.Trainer(updater, (n_epoch, 'epoch'), out='result')
    trainer.extend(extensions.Evaluator(test_iter, mlp))
    trainer.extend(extensions.LogReport())
    trainer.extend(extensions.PrintReport(
        ['epoch', 'main/accuracy', 'validation/main/accuracy']))
    trainer.extend(extensions.ProgressBar())
    trainer.run()


if __name__ == '__main__':
    main()