Last active
August 18, 2018 13:48
-
-
Save gui11aume/100838fbc931f42189a3312d19327a1a to your computer and use it in GitHub Desktop.
simple_rnn
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
# -*- coding:utf-8 -*- | |
import numpy as np | |
import sys | |
def softmax(x):
    """Numerically stable softmax: shift by the max before exponentiating."""
    expx = np.exp(x - x.max())
    return expx / expx.sum()
class Encoder:
    """Maps characters of a fixed alphabet to one-hot vectors and back."""

    def __init__(self, alphabet):
        self.alphabet = alphabet
        # Paired lookup tables: character -> index and index -> character.
        self.char_to_idx = {}
        self.idx_to_char = {}
        for i, ch in enumerate(alphabet):
            self.char_to_idx[ch] = i
            self.idx_to_char[i] = ch

    def onehot(self, c):
        """Return a one-hot numpy vector encoding the character `c`."""
        vec = np.zeros(len(self.alphabet))
        vec[self.char_to_idx[c]] = 1
        return vec

    def tochar(self, clist):
        """Decode a list of indices back into a string."""
        return ''.join(self.idx_to_char[i] for i in clist)
class FwdParams:
    """Forward-pass cache: hidden states `h` and output probabilities `p`.

    The initial hidden state is stored under key -1 so that time step 0
    can look up its predecessor uniformly as h[t-1].
    """

    def __init__(self, inith):
        self.h = {}
        self.h[-1] = inith
        self.p = {}
class BackpropParams:
    """Gradient accumulators (all zero-initialized) for the RNN parameters."""

    def __init__(self, xsz, hsz, ysz):
        # Biases of the output and hidden layers.
        self.dby, self.dbh = np.zeros(ysz), np.zeros(hsz)
        # Weights: output <- hidden, hidden <- hidden, hidden <- input,
        # each shaped exactly like the parameter it accumulates for.
        self.dWy = np.zeros((ysz, hsz))
        self.dWh = np.zeros((hsz, hsz))
        self.dWx = np.zeros((hsz, xsz))
class RNN:
    '''
    A minimalist recurrent neural network with one hidden layer.
    The hidden layer is tanh, the outer layer is softmax.
    '''

    def __init__(self, xsz, hsz, ysz):
        # Layer sizes: input, hidden and output.
        self.xsz = xsz
        self.hsz = hsz
        self.ysz = ysz
        # Weights start uniform in [-0.05, 0.05).
        self.Wx = np.random.rand(hsz, xsz) * 0.1 - 0.05
        self.Wh = np.random.rand(hsz, hsz) * 0.1 - 0.05
        self.Wy = np.random.rand(ysz, hsz) * 0.1 - 0.05
        # Biases start at zero.
        self.bh = np.zeros(hsz)
        self.by = np.zeros(ysz)

    def forward(self, inith, x):
        """Run the network over the sequence of input vectors `x`, starting
        from hidden state `inith`; return the cached states and outputs."""
        cache = FwdParams(inith)
        for t, xt in enumerate(x):
            preact = np.dot(self.Wx, xt) + \
                np.dot(self.Wh, cache.h[t-1]) + self.bh
            cache.h[t] = np.tanh(preact)
            cache.p[t] = softmax(np.dot(self.Wy, cache.h[t]) + self.by)
        return cache

    def bptt(self, fwd, x, z):
        """Backpropagation through time: accumulate the gradients of the
        cross-entropy loss of predictions `fwd.p` against targets `z`."""
        grads = BackpropParams(self.xsz, self.hsz, self.ysz)
        # Gradient flowing into the hidden state from later time steps.
        carry = np.zeros(self.hsz)
        for t in reversed(range(len(x))):
            # Softmax + cross-entropy yields this simple difference.
            dy = fwd.p[t] - z[t]
            grads.dby += dy
            grads.dWy += np.outer(dy, fwd.h[t])
            # The backpropagation through the hidden layer is not
            # so trivial. The recurrence formula below accumulates the
            # gradient so that it can be computed in a single pass.
            carry = (np.dot(self.Wh.T, carry) + np.dot(self.Wy.T, dy)) * \
                (1 - fwd.h[t]**2)
            grads.dbh += carry
            grads.dWh += np.outer(carry, fwd.h[t-1])
            grads.dWx += np.outer(carry, x[t])
        return grads

    def update(self, bwd, eps=.01):
        """In-place clipped gradient-descent step with learning rate `eps`."""
        for param, grad in ((self.bh, bwd.dbh), (self.by, bwd.dby),
                            (self.Wx, bwd.dWx), (self.Wh, bwd.dWh),
                            (self.Wy, bwd.dWy)):
            # Clip each gradient to [-5, 5] to tame exploding gradients.
            param -= eps * np.clip(grad, -5, 5)

    def generate(self, sz):
        """Sample `sz` characters (as indices) from the model, starting from
        a uniformly random first character and a zero hidden state. Returns
        a list of sz+1 indices (the random seed character included)."""
        idx = np.random.randint(self.xsz)
        onehot = np.zeros(self.xsz)
        onehot[idx] = 1
        sampled = [idx]
        hidden = np.zeros(self.hsz)
        for _ in range(sz):
            hidden = np.tanh(np.dot(self.Wx, onehot) +
                             np.dot(self.Wh, hidden) + self.bh)
            prob = softmax(np.dot(self.Wy, hidden) + self.by)
            idx = np.random.choice(range(self.xsz), p=prob.ravel())
            onehot = np.zeros(self.xsz)
            onehot[idx] = 1
            sampled.append(idx)
        return sampled
def main(f):
    """Train a character-level RNN on the text of open file `f`, printing a
    300-character sample every 50 batches. Loops over the text forever, so
    it runs until interrupted."""
    txt = f.read()
    txtlen = len(txt)
    # The alphabet is the sorted set of distinct characters in the text.
    alphbt = sorted(set(txt))
    # Parenthesized print so the script runs under both Python 2 and 3.
    print('text: %d\nalphabet: %d' % (txtlen, len(alphbt)))
    # Input, hidden and output layers are all sized to the alphabet.
    rnn = RNN(len(alphbt), len(alphbt), len(alphbt))
    convert = Encoder(alphbt)
    batchsz = 25
    n = 0
    while True:
        # Reset the hidden state at the start of every pass over the text.
        h = np.zeros(rnn.hsz)
        for pos in range(0, txtlen - batchsz, batchsz):
            # Read characters from text.
            x = [convert.onehot(txt[pos+i]) for i in range(batchsz)]
            # Targets are the inputs shifted forward by one character.
            z = x[1:] + [convert.onehot(txt[pos+batchsz])]
            # Train the RNN.
            fwd = rnn.forward(h, x)
            bwd = rnn.bptt(fwd, x, z)
            rnn.update(bwd)
            # Carry the hidden state over to the next batch.
            h = fwd.h[batchsz-1]
            n += 1
            if n % 50 == 0:
                print(convert.tochar(rnn.generate(300)))
if __name__ == '__main__':
    # Fix the RNG seed so weight initialization and sampling are reproducible.
    np.random.seed(123)
    # The text file to learn from is the first command-line argument.
    with open(sys.argv[1]) as f:
        main(f)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
This is a simplistic RNN (Recurrent Neural Network) implementation that learns the structure of a text character by character. It has few applications in practice, because such simple RNNs have only a short memory, but the logic of the main class
RNN
is useful for understanding how backpropagation through time (BPTT) is performed in RNNs. The only delicate part is figuring out the recurrence relation of the gradient as it passes through a hidden layer.