Skip to content

Instantly share code, notes, and snippets.

@machinaut
Created May 3, 2018 07:22
Show Gist options
  • Save machinaut/e264799f4b88e55e9c928dc80cc1e52b to your computer and use it in GitHub Desktop.
Save machinaut/e264799f4b88e55e9c928dc80cc1e52b to your computer and use it in GitHub Desktop.
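Toy RNN in NumPy: a tanh RNN cell with hand-written backpropagation, plus a finite-difference gradient-check test.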
#!/usr/bin/env python
# https://karpathy.github.io/2015/05/21/rnn-effectiveness/
import numpy as np
def rnn_fwd(x, h_in, Wxh, Whh, Why, bh, by):
    """Run one step of a vanilla tanh RNN cell.

    Computes the new hidden state ``h_out = tanh(x . Wxh + h_in . Whh + bh)``
    and the readout ``y = h_out . Why + by``.

    Returns:
        ``(h_out, y, cache)`` where ``cache`` holds every input followed by
        ``h_out``, in the order ``rnn_bak`` unpacks it.
    """
    pre_activation = np.dot(x, Wxh) + np.dot(h_in, Whh) + bh
    hidden = np.tanh(pre_activation)
    output = np.dot(hidden, Why) + by
    saved = (x, h_in, Wxh, Whh, Why, bh, by, hidden)
    return hidden, output, saved
def rnn_bak(dh_out, dy, cache):
    """Backward pass for a single ``rnn_fwd`` step.

    Args:
        dh_out: upstream gradient w.r.t. the step's ``h_out``.
        dy: upstream gradient w.r.t. the step's readout ``y``.
        cache: the tuple returned by ``rnn_fwd``.

    Returns:
        ``(dx, dh_in, dWxh, dWhh, dWhy, dbh, dby)``. The bias gradients are
        summed over axis 0 with ``keepdims=True``, so they keep a leading
        axis of length 1.
    """
    x, h_in, Wxh, Whh, Why, _bh, _by, h_out = cache

    # Readout layer: y = h_out . Why + by.
    grad_by = np.sum(dy, axis=0, keepdims=True)
    grad_Why = np.dot(h_out.T, dy)

    # h_out gets gradient from the readout and from dh_out; multiply by
    # tanh'(a) = 1 - tanh(a)**2, written in terms of the cached output.
    grad_pre = (1 - np.square(h_out)) * (np.dot(dy, Why.T) + dh_out)

    grad_bh = np.sum(grad_pre, axis=0, keepdims=True)
    grad_Wxh = np.dot(x.T, grad_pre)
    grad_Whh = np.dot(h_in.T, grad_pre)
    grad_x = np.dot(grad_pre, Wxh.T)
    grad_h_in = np.dot(grad_pre, Whh.T)
    return grad_x, grad_h_in, grad_Wxh, grad_Whh, grad_Why, grad_bh, grad_by
class RNN:
    """Single-layer tanh RNN (X -> H -> Y) trained with backprop through time.

    Sequences are processed as batches shaped ``(N, T, X)``; outputs are
    ``(N, T, Y)`` and the hidden state is ``(N, H)``.
    """

    def __init__(self, X=1, H=10, Y=1):
        """Create weights drawn from ``np.random.randn`` scaled by 1/sqrt(fan_in).

        Biases start at zero.
        """
        self.X = X
        self.H = H
        self.Y = Y
        self.Wxh = np.random.randn(X, H) / np.sqrt(X)
        self.Whh = np.random.randn(H, H) / np.sqrt(H)
        self.Why = np.random.randn(H, Y) / np.sqrt(H)
        self.bh = np.zeros(H)
        self.by = np.zeros(Y)
        # Not read by any method of this class; kept so existing callers that
        # access it still work.
        self.h = np.zeros(H)

    def fwd(self, x, h_in):
        """Run one cell step with this network's parameters (see ``rnn_fwd``)."""
        return rnn_fwd(x, h_in, self.Wxh, self.Whh, self.Why, self.bh, self.by)

    def forward(self, x):
        """Unroll the network over a batch of sequences.

        Args:
            x: input array of shape ``(N, T, X)``.

        Returns:
            ``(y, cache)``. ``y`` has shape ``(N, T, Y)``. ``cache`` maps each
            time step ``t`` to the per-step cache from ``rnn_fwd``. The hidden
            state starts at zero on every call.
        """
        N, T, _ = x.shape
        h = np.zeros((N, self.H))
        y = np.zeros((N, T, self.Y))
        cache = {}
        for t in range(T):
            h, y[:, t, :], cache[t] = self.fwd(x[:, t, :], h)
        return y, cache

    def backward(self, dy, cache):
        """Backpropagate output gradients through time.

        Args:
            dy: gradient of the loss w.r.t. the outputs, shape ``(N, T, Y)``.
            cache: the per-step cache dict returned by ``forward``.

        Returns:
            ``(dx, dh0, dWxh, dWhh, dWhy, dbh, dby)``. ``dx`` has shape
            ``(N, T, X)``, ``dh0`` is the gradient w.r.t. the initial hidden
            state, and the parameter gradients match the parameter shapes.
        """
        N, T, _ = dy.shape
        dx = np.zeros((N, T, self.X))
        dWxh = np.zeros((self.X, self.H))
        dWhh = np.zeros((self.H, self.H))
        dWhy = np.zeros((self.H, self.Y))
        dbh = np.zeros(self.H)
        dby = np.zeros(self.Y)
        # Gradient reaching h_t from later steps. Nothing comes after the last
        # step, so it starts at zero.
        dh_next = np.zeros((N, self.H))
        # BPTT must walk time backwards: the gradient into h_t is the dh_in
        # produced by step t + 1, which must already have been computed.
        for t in reversed(range(T)):
            dx[:, t, :], dh_next, tWxh, tWhh, tWhy, tbh, tby = rnn_bak(
                dh_next, dy[:, t, :], cache[t])
            dWxh += tWxh
            dWhh += tWhh
            dWhy += tWhy
            # rnn_bak returns bias grads with a leading axis of length 1.
            dbh += tbh[0]
            dby += tby[0]
        # After the loop, dh_next is the gradient w.r.t. the initial state.
        return dx, dh_next, dWxh, dWhh, dWhy, dbh, dby

    def supervise(self, x, z, alpha=0.01, clip=1.0):
        """Take one gradient-descent step towards targets ``z``.

        The output gradient is ``y - z``, i.e. the gradient of
        ``0.5 * sum((y - z)**2)``. Each parameter gradient is clipped
        element-wise to ``[-clip, clip]`` before the update.

        Note: the returned loss is a *mean*, while the gradient is an
        unnormalised *sum* over the batch and time. With large ``N * T`` the
        raw gradients are large, so clipping dominates the update size.

        Returns:
            Mean squared error of the predictions made before the update.
        """
        y, cache = self.forward(x)
        dy = y - z
        _, _, dWxh, dWhh, dWhy, dbh, dby = self.backward(dy, cache)
        self.Wxh -= alpha * dWxh.clip(-clip, clip)
        self.Whh -= alpha * dWhh.clip(-clip, clip)
        self.Why -= alpha * dWhy.clip(-clip, clip)
        self.bh -= alpha * dbh.clip(-clip, clip)
        self.by -= alpha * dby.clip(-clip, clip)
        return np.mean(np.square(dy))
if __name__ == '__main__':
    # Toy task: reproduce the input delayed by one time step.
    N, T, X, H, Y = 1000, 100, 2, 10, 2
    rnn = RNN(X=X, H=H, Y=Y)
    for step in range(1000):
        inputs = np.random.randn(N, T, X)
        # Target at step t is the input at t - 1; the first step's target is 0.
        targets = np.concatenate([np.zeros((N, 1, Y)), inputs[:, :-1, :]], axis=1)
        print(step, 'loss', rnn.supervise(inputs, targets))
#!/usr/bin/env python
import unittest
import numpy as np
from rnn import rnn_fwd, rnn_bak
def multi_index_iterator(x):
    """Yield every index tuple of array ``x`` in C (row-major) order.

    Only the shape of ``x`` is used, so read-only arrays are accepted. A
    zero-size array yields nothing, and a 0-d array yields ``()`` once.
    """
    # np.ndindex needs only the shape. An np.nditer built with
    # op_flags=['readwrite'] would reject read-only arrays, and without
    # 'zerosize_ok' it raises on zero-size arrays.
    yield from np.ndindex(np.shape(x))
def finite_difference(f, x, df, h=1e-6):
    """Estimate ``sum(df * d f(x) / dx)`` element-wise by central differences.

    Each element of ``x`` is nudged by ``+h`` and ``-h`` in turn. The
    difference of the two outputs is weighted by ``df``, summed and divided
    by ``2 * h``.

    Args:
        f: callable taking an array shaped like ``x``; its output must
            broadcast against ``df``.
        x: non-integer array. A copy is perturbed, never ``x`` itself.
        df: upstream gradient used to weight the output differences.
        h: perturbation size.

    Returns:
        Array shaped like ``x`` holding the numerical gradient.
    """
    assert not np.issubdtype(x.dtype, np.integer)
    probe = x.copy()
    probe.setflags(write=True)
    result = np.zeros_like(probe)

    def evaluate_at(index, value):
        # Set one element, then snapshot f's output in case the next call
        # would overwrite the buffer it returned.
        probe[index] = value
        return f(probe).copy()

    for index in multi_index_iterator(probe):
        center = probe[index]
        f_plus = evaluate_at(index, center + h)
        f_minus = evaluate_at(index, center - h)
        probe[index] = center  # restore before moving to the next element
        result[index] = np.sum((f_plus - f_minus) * df) / (2 * h)
    return result
class TestBackprop(unittest.TestCase):
    """Gradient checks for the single-step RNN cell."""

    def test_rnn(self):
        """``rnn_bak`` must agree with finite differences of ``rnn_fwd``."""
        rng = np.random.RandomState(0)
        n, n_x, n_h, n_y = 3, 4, 5, 6
        # Draw order is fixed so seed 0 always produces the same data:
        # x, h_in, Wxh, Whh, Why, bh, by, then upstream grads dh_out, dy.
        shapes = [
            (n, n_x), (n, n_h), (n_x, n_h), (n_h, n_h), (n_h, n_y),
            (n_h,), (n_y,),
            (n, n_h), (n, n_y),
        ]
        drawn = [rng.randn(*shape) for shape in shapes]
        for arr in drawn:
            # Read-only, so an in-place write by the code under test raises.
            arr.setflags(write=False)
        inputs, upstream = drawn[:7], drawn[7:]

        _, _, cache = rnn_fwd(*inputs)
        analytic = rnn_bak(*upstream, cache)

        numeric = [np.zeros_like(grad) for grad in analytic]
        # rnn_fwd returns (h_out, y, cache); output k pairs with upstream[k].
        for out_idx, upstream_grad in enumerate(upstream):
            for arg_idx, total in enumerate(numeric):
                def perturbed(value, arg_idx=arg_idx, out_idx=out_idx):
                    args = list(inputs)
                    args[arg_idx] = value
                    return rnn_fwd(*args)[out_idx]
                total += finite_difference(perturbed, inputs[arg_idx], upstream_grad)

        # Order: dx, dh_in, dWxh, dWhh, dWhy, dbh, dby. The weight gradients
        # use the looser tolerance; 1e-7 is assert_allclose's default rtol.
        rtols = (1e-7, 1e-7, 1e-6, 1e-6, 1e-6, 1e-7, 1e-7)
        for got, expected, rtol in zip(analytic, numeric, rtols):
            np.testing.assert_allclose(got, expected, rtol=rtol)
if __name__ == '__main__':
    # Run every unittest.TestCase defined in this module when executed directly.
    unittest.main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment