Created
May 3, 2018 07:22
-
-
Save machinaut/e264799f4b88e55e9c928dc80cc1e52b to your computer and use it in GitHub Desktop.
toy rnn
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
# https://karpathy.github.io/2015/05/21/rnn-effectiveness/ | |
import numpy as np | |
def rnn_fwd(x, h_in, Wxh, Whh, Why, bh, by):
    """Run one step of a tanh recurrent cell followed by a linear readout.

    Computes ``h_out = tanh(x.Wxh + h_in.Whh + bh)`` and ``y = h_out.Why + by``.
    Callers in this file pass ``x`` as (N, X), ``h_in`` as (N, H) and weights
    shaped (X, H), (H, H), (H, Y); the function itself does not check shapes.

    Returns:
        (h_out, y, cache), where ``cache`` is the tuple
        ``(x, h_in, Wxh, Whh, Why, bh, by, h_out)`` consumed by ``rnn_bak``.
    """
    # Pre-activation: input contribution, then recurrent contribution, then bias.
    pre_activation = x.dot(Wxh) + h_in.dot(Whh) + bh
    h_out = np.tanh(pre_activation)
    # Linear readout of the new hidden state.
    y = h_out.dot(Why) + by
    return h_out, y, (x, h_in, Wxh, Whh, Why, bh, by, h_out)
def rnn_bak(dh_out, dy, cache):
    """Backpropagate through one ``rnn_fwd`` step.

    Args:
        dh_out: gradient flowing into ``h_out`` from outside this step
            (e.g. from a later timestep).
        dy: gradient with respect to this step's output ``y``.
        cache: the tuple returned as the third value of ``rnn_fwd``.

    Returns:
        (dx, dh_in, dWxh, dWhh, dWhy, dbh, dby). The bias gradients are
        summed over axis 0 with ``keepdims=True``, so they keep a leading
        axis of length 1.
    """
    x, h_in, Wxh, Whh, Why, bh, by, h_out = cache

    # Readout layer: y = h_out.Why + by
    dWhy = h_out.T.dot(dy)
    dby = dy.sum(axis=0, keepdims=True)

    # Total gradient reaching h_out, pushed through tanh (d tanh = 1 - tanh^2).
    tanh_slope = 1 - np.square(h_out)
    grad_into_h = dy.dot(Why.T) + dh_out
    dh = tanh_slope * grad_into_h

    # Pre-activation is affine in x, h_in and bh.
    dWxh = x.T.dot(dh)
    dWhh = h_in.T.dot(dh)
    dbh = dh.sum(axis=0, keepdims=True)
    dx = dh.dot(Wxh.T)
    dh_in = dh.dot(Whh.T)

    return dx, dh_in, dWxh, dWhh, dWhy, dbh, dby
class RNN:
    """Minimal tanh recurrent network trained with backprop through time.

    Parameters are plain NumPy arrays: ``Wxh`` (X, H), ``Whh`` (H, H),
    ``Why`` (H, Y), ``bh`` (H,) and ``by`` (Y,). The per-step math lives in
    the module-level ``rnn_fwd`` / ``rnn_bak`` functions.
    """

    def __init__(self, X=1, H=10, Y=1):
        """Create a network with X inputs, H hidden units and Y outputs."""
        self.X = X
        self.H = H
        self.Y = Y
        # Gaussian init scaled by 1/sqrt(fan-in); biases start at zero.
        self.Wxh = np.random.randn(X, H) / np.sqrt(X)
        self.Whh = np.random.randn(H, H) / np.sqrt(H)
        self.Why = np.random.randn(H, Y) / np.sqrt(H)
        self.bh = np.zeros(H)
        self.by = np.zeros(Y)
        # Not read by any method of this class; kept so existing attribute
        # access keeps working.
        self.h = np.zeros(H)

    def fwd(self, x, h_in):
        """Run a single timestep with this network's parameters.

        Returns the ``(h_out, y, cache)`` triple produced by ``rnn_fwd``.
        """
        return rnn_fwd(x, h_in, self.Wxh, self.Whh, self.Why, self.bh, self.by)

    def forward(self, x):
        """Unroll the network over a batch of sequences.

        Args:
            x: array of shape (N, T, X).

        Returns:
            (y, cache): ``y`` has shape (N, T, Y); ``cache`` maps each
            timestep index to the per-step cache needed by ``backward``.
            The hidden state starts from zeros on every call.
        """
        N, T, _ = x.shape
        h = np.zeros((N, self.H))
        y = np.zeros((N, T, self.Y))
        cache = {}
        for t in range(T):
            h, y[:, t, :], cache[t] = self.fwd(x[:, t, :], h)
        return y, cache

    def backward(self, dy, cache):
        """Backpropagate through time over a sequence unrolled by ``forward``.

        Args:
            dy: gradient with respect to the outputs, shape (N, T, Y).
            cache: the per-timestep cache dict returned by ``forward``.

        Returns:
            (dx, dh0, dWxh, dWhh, dWhy, dbh, dby), where ``dx`` has shape
            (N, T, X), ``dh0`` is the gradient with respect to the initial
            hidden state, and the parameter gradients are summed over all
            timesteps.
        """
        N, T, _ = dy.shape
        dx = np.zeros((N, T, self.X))
        dWxh = np.zeros((self.X, self.H))
        dWhh = np.zeros((self.H, self.H))
        dWhy = np.zeros((self.H, self.Y))
        dbh = np.zeros(self.H)
        dby = np.zeros(self.Y)
        # Nothing follows the last timestep, so its incoming hidden-state
        # gradient is zero.
        th = np.zeros((N, self.H))
        # rnn_bak returns the gradient w.r.t. the *previous* hidden state,
        # so timesteps must be visited last-to-first for that gradient to be
        # handed to the step that actually produced the state.
        for t in reversed(range(T)):
            dx[:, t, :], th, tWxh, tWhh, tWhy, tbh, tby = rnn_bak(th, dy[:, t, :], cache[t])
            dWxh += tWxh
            dWhh += tWhh
            dWhy += tWhy
            # Bias gradients come back with a leading length-1 axis.
            dbh += tbh.ravel()
            dby += tby.ravel()
        return dx, th, dWxh, dWhh, dWhy, dbh, dby

    def supervise(self, x, z, alpha=0.01, clip=1.0):
        """Take one clipped gradient-descent step toward targets ``z``.

        The update uses ``dy = y - z`` (the gradient of half the summed
        squared error) and clips every gradient entry to ``[-clip, clip]``
        before scaling by ``alpha``.

        Returns:
            The mean squared error of the predictions made *before* the
            parameter update.
        """
        y, cache = self.forward(x)
        dy = y - z
        _, _, dWxh, dWhh, dWhy, dbh, dby = self.backward(dy, cache)
        self.Wxh -= alpha * dWxh.clip(-clip, clip)
        self.Whh -= alpha * dWhh.clip(-clip, clip)
        self.Why -= alpha * dWhy.clip(-clip, clip)
        self.bh -= alpha * dbh.clip(-clip, clip)
        self.by -= alpha * dby.clip(-clip, clip)
        return np.mean(np.square(y - z))
if __name__ == '__main__':
    # Toy task: at every timestep the target is the previous timestep's
    # input (zeros at t=0), so the network must remember one step back.
    batch, steps, n_in, n_hidden, n_out = 1000, 100, 2, 10, 2
    rnn = RNN(X=n_in, H=n_hidden, Y=n_out)
    for i in range(1000):
        x = np.random.randn(batch, steps, n_in)
        z = np.zeros((batch, steps, n_out))
        z[:, 1:, :] = x[:, :-1, :]
        print(i, 'loss', rnn.supervise(x, z))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
import unittest | |
import numpy as np | |
from rnn import rnn_fwd, rnn_bak | |
def multi_index_iterator(x):
    """Yield the index tuple of every element of array ``x``.

    Iteration is driven by ``np.nditer`` opened with the ``multi_index``
    flag and read-write access, so ``x`` must be writeable.
    """
    walker = np.nditer(x, flags=['multi_index'], op_flags=['readwrite'])
    # The element value itself is not needed, only the iterator's position.
    for _ in walker:
        yield walker.multi_index
def finite_difference(f, x, df, h=1e-6):
    """Numerically estimate the gradient of ``sum(f(x) * df)`` w.r.t. ``x``.

    Each entry of ``x`` is perturbed by ``+h`` and ``-h`` in turn (central
    difference) on a private writeable copy, so the caller's array is never
    modified.

    Args:
        f: function mapping an array shaped like ``x`` to an array.
        x: point at which to differentiate; must not have an integer dtype.
        df: upstream gradient, multiplied elementwise with ``f``'s output.
        h: perturbation size.

    Returns:
        An array shaped like ``x`` holding the estimated gradient.
    """
    assert not np.issubdtype(x.dtype, np.integer)
    probe = x.copy()
    probe.setflags(write=True)
    estimate = np.zeros_like(probe)
    span = 2 * h
    for idx in multi_index_iterator(probe):
        center = probe[idx]
        probe[idx] = center + h
        f_plus = f(probe).copy()
        probe[idx] = center - h
        f_minus = f(probe).copy()
        # Restore the entry before moving on to the next one.
        probe[idx] = center
        estimate[idx] = np.sum((f_plus - f_minus) * df) / span
    return estimate
class TestBackprop(unittest.TestCase):
    """Numerical gradient check for the single-step RNN functions."""

    def test_rnn(self):
        """rnn_bak must agree with central differences taken through rnn_fwd."""
        rng = np.random.RandomState(0)
        N, X, H, Y = 3, 4, 5, 6
        # Draw order is x, h_in, Wxh, Whh, Why, bh, by, dh_out, dy; keeping
        # it fixed keeps the sampled values reproducible.
        shapes = ((N, X), (N, H), (X, H), (H, H), (H, Y), (H,), (Y,), (N, H), (N, Y))
        arrays = [rng.randn(*shape) for shape in shapes]
        # Read-only inputs make any accidental in-place mutation raise.
        for arr in arrays:
            arr.setflags(write=False)
        inputs, (dh_out, dy) = arrays[:7], arrays[7:]

        cache = rnn_fwd(*inputs)[2]
        # Ordered like `inputs`: dx, dh_in, dWxh, dWhh, dWhy, dbh, dby.
        analytic = rnn_bak(dh_out, dy, cache)

        def numeric_grad(pos):
            # Sum the contributions through both outputs (h_out, then y).
            total = np.zeros_like(analytic[pos])
            for out_idx, upstream in enumerate((dh_out, dy)):
                def perturbed(value):
                    args = list(inputs)
                    args[pos] = value
                    return rnn_fwd(*args)[out_idx]
                total += finite_difference(perturbed, inputs[pos], upstream)
            return total

        numeric = [numeric_grad(pos) for pos in range(len(inputs))]

        # The weight matrices are compared with a looser relative tolerance;
        # everything else uses assert_allclose's default.
        rtols = (None, None, 1e-6, 1e-6, 1e-6, None, None)
        for got, want, rtol in zip(analytic, numeric, rtols):
            if rtol is None:
                np.testing.assert_allclose(got, want)
            else:
                np.testing.assert_allclose(got, want, rtol=rtol)
if __name__ == '__main__':
    # Run the gradient checks when this file is executed directly.
    unittest.main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment