@szdr
Last active March 21, 2016 16:32
import numpy as np
from chainer import Function, Variable, optimizers
from chainer import Chain
import chainer.functions as F
import chainer.links as L


class NN(Chain):
    """A 3-2-2 network: linear -> sigmoid -> linear, with no bias terms."""

    def __init__(self):
        # Fixed initial weights so every printed value is reproducible by hand.
        initial_W1 = np.array([[.1, .1, .1], [.2, .2, .2]], dtype=np.float32)
        initial_W2 = np.array([[1, 2], [3, 4]], dtype=np.float32)
        super(NN, self).__init__(
            l1=L.Linear(3, 2, nobias=True, initialW=initial_W1),
            l2=L.Linear(2, 2, nobias=True, initialW=initial_W2)
        )

    def __call__(self, x):
        # Print each intermediate of the forward pass.
        a_1 = self.l1(x)
        print("a_1\n{}".format(a_1.data))
        z = F.sigmoid(a_1)
        print("z\n{}".format(z.data))
        a_2 = self.l2(z)
        print("a_2\n{}".format(a_2.data))
        y = a_2  # identity activation on the output layer
        print("y\n{}".format(y.data))
        return y


class SquaredError(Function):
    """Squared error E = ||x0 - x1||^2 / 2, written as a custom Function."""

    def forward(self, inputs):
        x0, x1 = inputs
        self.diff = x0 - x1
        diff = self.diff.ravel()
        # Return a float32 0-d array so the output dtype matches the inputs.
        return np.asarray(diff.dot(diff) / 2., dtype=np.float32),

    def backward(self, inputs, gy):
        # Scale the cached difference by the upstream gradient gy[0].
        gx0 = gy[0] * self.diff
        return gx0, -gx0


def squared_error(x0, x1):
    return SquaredError()(x0, x1)
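
# Note: with E = ||x0 - x1||^2 / 2, the chain rule gives dE/dx0 = x0 - x1
# and dE/dx1 = -(x0 - x1), so backward() returns the cached difference with
# opposite signs; gy[0] is 1 here because the loss is the end of the graph.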


if __name__ == '__main__':
    nn = NN()
    x = Variable(np.array([[1, 2, 3]], dtype=np.float32))
    t = Variable(np.array([[0, 1]], dtype=np.float32))
    y = nn(x)

    optimizer = optimizers.SGD(lr=0.1)
    optimizer.setup(nn)
    nn.zerograds()  # clear accumulated gradients before backward()

    loss = squared_error(y, t)
    loss.backward()
    print("loss\n{}".format(loss.data))
    print("W^1 grad\n{}".format(nn.l1.W.grad))
    print("W^2 grad\n{}".format(nn.l2.W.grad))

    print("before W^1\n{}".format(nn.l1.W.data))
    print("before W^2\n{}".format(nn.l2.W.data))
    optimizer.update()  # one SGD step: W <- W - lr * grad
    print("after W^1\n{}".format(nn.l1.W.data))
    print("after W^2\n{}".format(nn.l2.W.data))