@karpathy
Last active March 18, 2024 20:43
Minimal character-level language model with a Vanilla Recurrent Neural Network, in Python/numpy
"""
Minimal character-level Vanilla RNN model. Written by Andrej Karpathy (@karpathy)
BSD License
"""
import numpy as np

# data I/O
data = open('input.txt', 'r').read() # should be simple plain text file
chars = list(set(data))
data_size, vocab_size = len(data), len(chars)
print 'data has %d characters, %d unique.' % (data_size, vocab_size)
char_to_ix = { ch:i for i,ch in enumerate(chars) }
ix_to_char = { i:ch for i,ch in enumerate(chars) }

# hyperparameters
hidden_size = 100 # size of hidden layer of neurons
seq_length = 25 # number of steps to unroll the RNN for
learning_rate = 1e-1

# model parameters
Wxh = np.random.randn(hidden_size, vocab_size)*0.01 # input to hidden
Whh = np.random.randn(hidden_size, hidden_size)*0.01 # hidden to hidden
Why = np.random.randn(vocab_size, hidden_size)*0.01 # hidden to output
bh = np.zeros((hidden_size, 1)) # hidden bias
by = np.zeros((vocab_size, 1)) # output bias

def lossFun(inputs, targets, hprev):
  """
  inputs,targets are both list of integers.
  hprev is Hx1 array of initial hidden state
  returns the loss, gradients on model parameters, and last hidden state
  """
  xs, hs, ys, ps = {}, {}, {}, {}
  hs[-1] = np.copy(hprev)
  loss = 0
  # forward pass
  for t in xrange(len(inputs)):
    xs[t] = np.zeros((vocab_size,1)) # encode in 1-of-k representation
    xs[t][inputs[t]] = 1
    hs[t] = np.tanh(np.dot(Wxh, xs[t]) + np.dot(Whh, hs[t-1]) + bh) # hidden state
    ys[t] = np.dot(Why, hs[t]) + by # unnormalized log probabilities for next chars
    ps[t] = np.exp(ys[t]) / np.sum(np.exp(ys[t])) # probabilities for next chars
    loss += -np.log(ps[t][targets[t],0]) # softmax (cross-entropy loss)
  # backward pass: compute gradients going backwards
  dWxh, dWhh, dWhy = np.zeros_like(Wxh), np.zeros_like(Whh), np.zeros_like(Why)
  dbh, dby = np.zeros_like(bh), np.zeros_like(by)
  dhnext = np.zeros_like(hs[0])
  for t in reversed(xrange(len(inputs))):
    dy = np.copy(ps[t])
    dy[targets[t]] -= 1 # backprop into y. see http://cs231n.github.io/neural-networks-case-study/#grad if confused here
    dWhy += np.dot(dy, hs[t].T)
    dby += dy
    dh = np.dot(Why.T, dy) + dhnext # backprop into h
    dhraw = (1 - hs[t] * hs[t]) * dh # backprop through tanh nonlinearity
    dbh += dhraw
    dWxh += np.dot(dhraw, xs[t].T)
    dWhh += np.dot(dhraw, hs[t-1].T)
    dhnext = np.dot(Whh.T, dhraw)
  for dparam in [dWxh, dWhh, dWhy, dbh, dby]:
    np.clip(dparam, -5, 5, out=dparam) # clip to mitigate exploding gradients
  return loss, dWxh, dWhh, dWhy, dbh, dby, hs[len(inputs)-1]

def sample(h, seed_ix, n):
  """
  sample a sequence of integers from the model
  h is memory state, seed_ix is seed letter for first time step
  """
  x = np.zeros((vocab_size, 1))
  x[seed_ix] = 1
  ixes = []
  for t in xrange(n):
    h = np.tanh(np.dot(Wxh, x) + np.dot(Whh, h) + bh)
    y = np.dot(Why, h) + by
    p = np.exp(y) / np.sum(np.exp(y))
    ix = np.random.choice(range(vocab_size), p=p.ravel())
    x = np.zeros((vocab_size, 1))
    x[ix] = 1
    ixes.append(ix)
  return ixes

n, p = 0, 0
mWxh, mWhh, mWhy = np.zeros_like(Wxh), np.zeros_like(Whh), np.zeros_like(Why)
mbh, mby = np.zeros_like(bh), np.zeros_like(by) # memory variables for Adagrad
smooth_loss = -np.log(1.0/vocab_size)*seq_length # loss at iteration 0
while True:
  # prepare inputs (we're sweeping from left to right in steps seq_length long)
  if p+seq_length+1 >= len(data) or n == 0:
    hprev = np.zeros((hidden_size,1)) # reset RNN memory
    p = 0 # go from start of data
  inputs = [char_to_ix[ch] for ch in data[p:p+seq_length]]
  targets = [char_to_ix[ch] for ch in data[p+1:p+seq_length+1]]

  # sample from the model now and then
  if n % 100 == 0:
    sample_ix = sample(hprev, inputs[0], 200)
    txt = ''.join(ix_to_char[ix] for ix in sample_ix)
    print '----\n %s \n----' % (txt, )

  # forward seq_length characters through the net and fetch gradient
  loss, dWxh, dWhh, dWhy, dbh, dby, hprev = lossFun(inputs, targets, hprev)
  smooth_loss = smooth_loss * 0.999 + loss * 0.001
  if n % 100 == 0: print 'iter %d, loss: %f' % (n, smooth_loss) # print progress

  # perform parameter update with Adagrad
  for param, dparam, mem in zip([Wxh, Whh, Why, bh, by],
                                [dWxh, dWhh, dWhy, dbh, dby],
                                [mWxh, mWhh, mWhy, mbh, mby]):
    mem += dparam * dparam
    param += -learning_rate * dparam / np.sqrt(mem + 1e-8) # adagrad update

  p += seq_length # move data pointer
  n += 1 # iteration counter
@Rumi4
Rumi4 commented Jul 15, 2019

Does the program generate new text, or the same text as in the input file?

@pzdkn
pzdkn commented Jul 20, 2019

Stupid question: in which lines does the vanishing gradient problem manifest itself?
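
A minimal sketch for readers with the same question, not part of the gist: in lossFun the backward loop multiplies the hidden-state gradient by (1 - hs[t] * hs[t]) and then by Whh.T on every step (the dhraw and dhnext lines), and that repeated product is where the gradient can shrink. The snippet below uses random stand-in hidden states (an assumption for illustration) and tracks only the part of the gradient that flows back through time, ignoring the per-step contribution from dy.

```python
import numpy as np

rng = np.random.RandomState(0)
hidden_size, steps = 100, 25  # same values as the gist's hidden_size and seq_length

# Small recurrent weights, initialised like the gist (randn * 0.01).
Whh = rng.randn(hidden_size, hidden_size) * 0.01
# Stand-in hidden states in (-1, 1); in the real model these come from the forward pass.
hs = [np.tanh(rng.randn(hidden_size, 1)) for _ in range(steps)]

dhnext = np.ones((hidden_size, 1))  # pretend gradient arriving at the last step
norms = []
for t in reversed(range(steps)):
    dhraw = (1 - hs[t] * hs[t]) * dhnext  # tanh' factor, never larger than 1
    dhnext = np.dot(Whh.T, dhraw)         # multiply by Whh.T on the way back
    norms.append(float(np.linalg.norm(dhnext)))

# Norm after one backward step versus after all 25 steps.
print(norms[0], norms[-1])
```

With larger recurrent weights the same product can grow instead of shrink, which is what the np.clip call in lossFun guards against.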

@javaidnabi31
javaidnabi31 commented Jul 21, 2019

Please read my post below, which discusses the implementation step by step:
https://towardsdatascience.com/recurrent-neural-networks-rnns-3f06d7653a85

@Mahawi3211
Mahawi3211 commented Jul 22, 2019 via email

@hwdong-net
I didn't see where or when to stop or break out of the "while True:" loop?

@BlockListed
I am at over 630k iterations and my loss is 42.9. I'm using Shakespeare's poems as training data. Am I doing something wrong?
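
For context, a small sketch of the arithmetic: the loss printed by the script is summed over the seq_length = 25 characters of each chunk, so dividing by 25 gives a per-character figure. The vocabulary size of 65 below is an assumed value for illustration, not something reported in the comment.

```python
import numpy as np

seq_length = 25        # from the gist
reported_loss = 42.9   # smoothed loss quoted in the comment above
vocab_size = 65        # assumption: hypothetical number of unique characters

# Average cross-entropy per character, in nats.
per_char = reported_loss / seq_length
# Loss of a model that guesses uniformly, as in the gist's smooth_loss initialisation.
uniform_loss = -np.log(1.0 / vocab_size) * seq_length

print('per-character loss: %.3f nats' % per_char)
print('uniform-guess loss for one chunk: %.1f' % uniform_loss)
```

Comparing the reported number with the uniform-guess value for the actual vocabulary is a quick way to tell whether the model has learned anything.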

@mkffl
mkffl commented Sep 22, 2019

To anyone finding the code hard to understand, I provide detailed explanations here: https://mkffl.github.io/
Hope it helps

@Mahawi3211
Mahawi3211 commented Sep 25, 2019 via email

@BiaChaudhry
@karpathy thanks a lot for this posting! I tried to run it with your hello example ... and this is what I get

Traceback (most recent call last):
  File "min-char-rnn.py", line 100, in <module>
    loss, dWxh, dWhh, dWhy, dbh, dby, hprev = lossFun(inputs, targets, hprev)
  File "min-char-rnn.py", line 43, in lossFun
    loss += -np.log(ps[t][targets[t],0]) # softmax (cross-entropy loss)
IndexError: list index out of range

Any ideas?
Thanks!

This only works for a file with at least a few sentences. The input has to be at least seq_length + 1 characters long (26 with the default seq_length = 25); with a shorter file the targets slice comes out shorter than the inputs slice, so targets[t] runs out of range.
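
A minimal sketch of how a very short input.txt triggers the IndexError, assuming the file contains just the word hello: slicing past the end of a string silently truncates, so the two lists built in the main loop end up with different lengths.

```python
data = 'hello'   # assumption: a tiny input file
seq_length = 25  # the gist's default
p = 0

inputs = list(data[p:p + seq_length])           # all 5 characters
targets = list(data[p + 1:p + seq_length + 1])  # only 4 characters

print(len(inputs), len(targets))

# The forward loop in lossFun runs over len(inputs),
# so the last targets[t] lookup is out of range.
try:
    for t in range(len(inputs)):
        targets[t]
    failed = False
except IndexError:
    failed = True
print(failed)
```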

@suncerock
Thanks for sharing! It helps a lot!

@lanttu1243
I switched numpy to cupy, and when I run this code I get the following error at iteration 100, caused by something turning into NaN during training. What can I do to fix this?
Traceback (most recent call last):
  File "C:\...\hello.py", line 168, in <module>
    sample_ix = sample(hprev, inputs[0], 100)
  File "C:\...\hello.py", line 130, in sample
    ix = int(np.random.choice(range(vocab_size), (1, 1), p=p))
  File "C:\...\cupy\random\sample.py", line 196, in choice
    return rs.choice(a, size, replace, p)
  File "C:\...\cupy\random\generator.py", line 982, in choice
    raise ValueError('probabilities are not non-negative')
ValueError: probabilities are not non-negative
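
One common way probabilities turn into NaN in code like this is overflow in np.exp when the scores get large; whether that is what happened in the cupy run above is not known. A hedged sketch in plain numpy, with made-up scores, comparing the gist's softmax expression with the usual max-subtraction variant:

```python
import numpy as np

y = np.array([[1000.0], [999.0], [0.0]])  # made-up, deliberately large scores

# Silence the overflow / invalid-value warnings for the demonstration.
with np.errstate(over='ignore', invalid='ignore'):
    p_naive = np.exp(y) / np.sum(np.exp(y))  # expression used in the gist

# Subtracting the max leaves the softmax unchanged mathematically
# but keeps every exponent <= 0, so np.exp cannot overflow.
shifted = y - np.max(y)
p_stable = np.exp(shifted) / np.sum(np.exp(shifted))

print(np.isnan(p_naive).any())
print(p_stable.ravel(), p_stable.sum())
```

Lowering learning_rate is another thing to try if the scores blow up during training.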

@djamelherbadji
I want the Python code for a neural network where the input layer has two neurons, the hidden layer consists of two sub-layers of 20 and 10 neurons respectively, and the output layer has 5 neurons.

@ehzawad
ehzawad commented Aug 23, 2020

Traceback (most recent call last):
  File "minimal_rnn.py", line 100, in <module>
    loss, dWxh, dWhh, dWhy, dbh, dby, hprev = lossFun(inputs, targets, hprev)
  File "minimal_rnn.py", line 43, in lossFun
    loss += -np.log(ps[t][targets[t],0]) # softmax (cross-entropy loss)
IndexError: list index out of range

Why am I getting this error?

@MistryWoman
Can somebody explain why we need to multiply by 0.01 when initializing the weights? Aren't we already using np.random.randn to sample from a normal distribution?
I am talking about this specific line of code
Wxh = np.random.randn(hidden_size, vocab_size)*0.01

@avirichie
@MistryWoman It's essential to break the symmetry here. The output of each unit depends on the sum of its inputs multiplied by the corresponding weights. If all weights are set to zero, every hidden unit gets the same signal (zero in this case), no matter what the input was. More generally, if all weights are the same, all units in the hidden layer compute the same thing. That is not what we want, because different hidden units should compute different functions, and that is impossible if you initialize them all to the same value.

@j8ahmed
j8ahmed commented Oct 17, 2020

Wonderful code, but some of it confused me.

# perform parameter update with Adagrad
    for param, dparam, mem in zip([Wxh, Whh, Why, bh, by], 
                                  [dWxh, dWhh, dWhy, dbh, dby], 
                                [mWxh, mWhh, mWhy, mbh, mby]):
        mem += dparam * dparam  # accumulate the gradients
        param += -learning_rate * dparam / np.sqrt(mem + 1e-8) # adagrad update

Wxh, ..., mby were defined as global variables, while param, dparam and mem are just local variables. How can the Adagrad update change the values of the global variables? I tried to test my idea with the code below.

import numpy as np
Wxh, Whh, Why, bh, by=1,2,3,4,5
dWxh, dWhh, dWhy, dbh, dby=1,2,3,4,5
mWxh, mWhh, mWhy, mbh, mby=1,2,3,4,5
while True:
    for param, dparam, mem in zip([Wxh, Whh, Why, bh, by], 
                                  [dWxh, dWhh, dWhy, dbh, dby], 
                                [mWxh, mWhh, mWhy, mbh, mby]):
        mem += dparam * dparam  # accumulate the gradients
        param += dparam / np.sqrt(mem + 1e-8)
    print(Wxh)  # output never changes!

It might be a very simple problem, but it puzzled me. Can someone help? Thanks a lot.

Not sure if anyone ever addressed this further down the thread, but I found an answer on Stack Overflow about modifying mutable global variables like Wxh within a local scope.

'param' refers to each weight array in turn rather than to a copy of its values stored somewhere new. So when the Adagrad update adds to 'param' with +=, we are actually modifying the global arrays Wxh, etc. in place, not some separate value held in param. I think?

https://stackoverflow.com/questions/28566563/global-scope-when-accessing-array-element-inside-function
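
A small sketch of the difference, using toy values that are not from the gist: += on a numpy array modifies the array object in place, so every name bound to it sees the change, while += on a Python int only rebinds the loop variable. That would explain why the test above with plain integers never changes Wxh.

```python
import numpy as np

# Case 1: numpy arrays, as in the gist.
W = np.zeros((2, 2))
for param in [W]:
    param += 1.0         # in place: param and W are the same array object
print(W)                 # W has changed

# Case 2: plain integers, as in the test above.
w = 1
for param in [w]:
    param += 1           # ints are immutable: this rebinds param only
print(w)                 # w is unchanged

# Pitfall: 'param = param + 1.0' builds a new array and leaves V untouched.
V = np.zeros((2, 2))
for param in [V]:
    param = param + 1.0
print(V)
```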

@adityagupta679
@MistryWoman We are trying to keep our weights close to zero so that the gradient doesn't vanish. With large weights, tanh saturates and its gradient becomes very small. Please take a look at https://gist.github.com/karpathy/d4dee566867f8291f086#gistcomment-2180482
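
A rough sketch of this effect: the tanh derivative used in the backward pass is 1 - h*h, which goes towards zero when the pre-activations are large. The random stand-in hidden state and the helper name mean_tanh_grad are assumptions for illustration, so treat the numbers as qualitative only.

```python
import numpy as np

rng = np.random.RandomState(0)
hidden_size = 100                            # the gist's default
h_prev = np.tanh(rng.randn(hidden_size, 1))  # stand-in previous hidden state
base = rng.randn(hidden_size, hidden_size)   # unscaled recurrent weights

def mean_tanh_grad(scale):
    """Average tanh derivative (1 - h*h) when the recurrent weights are base * scale."""
    h = np.tanh(np.dot(base * scale, h_prev))
    return float(np.mean(1 - h * h))

grad_small = mean_tanh_grad(0.01)  # the gist's initialisation scale
grad_large = mean_tanh_grad(1.0)   # plain randn, no scaling
print(grad_small, grad_large)
```

The scaled initialisation keeps most units in the near-linear part of tanh, while unscaled weights push many of them into the flat tails.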

@iVishalr
@karpathy Thank you for posting this. I have a small question about line 58. Why do we use Whh.T in np.dot(Whh.T, dhraw)? Since Whh has shape (hidden_size, hidden_size), we could multiply it with dhraw directly. Or am I missing something here?
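
The shapes do line up either way because Whh is square, but the values differ. A minimal finite-difference sketch (toy sizes and a made-up scalar loss, both assumptions for illustration) suggesting that the gradient with respect to the previous hidden state is Whh.T times dhraw rather than Whh times dhraw:

```python
import numpy as np

rng = np.random.RandomState(0)
H = 4                            # toy hidden size
Whh = rng.randn(H, H) * 0.5
hprev = rng.randn(H, 1)
c = rng.randn(H, 1)              # made-up weights that define a scalar loss

def loss(hp):
    """Toy loss: weighted sum of the next hidden state tanh(Whh . hp)."""
    return float(np.sum(c * np.tanh(np.dot(Whh, hp))))

# Analytic backward pass, mirroring the gist's dhraw / dhnext lines.
h = np.tanh(np.dot(Whh, hprev))
dhraw = (1 - h * h) * c
with_T = np.dot(Whh.T, dhraw)
without_T = np.dot(Whh, dhraw)

# Central finite differences, one coordinate of hprev at a time.
eps = 1e-5
numeric = np.zeros_like(hprev)
for i in range(H):
    d = np.zeros_like(hprev)
    d[i] = eps
    numeric[i] = (loss(hprev + d) - loss(hprev - d)) / (2 * eps)

# Largest absolute error of each candidate against the numerical gradient.
print(np.abs(numeric - with_T).max(), np.abs(numeric - without_T).max())
```

The transpose appears because the forward pass computes Whh times h, so each entry of h contributes through a column of Whh, and summing those contributions is a product with Whh.T.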

@smita-09
Traceback (most recent call last):
  File "minimal_rnn.py", line 100, in <module>
    loss, dWxh, dWhh, dWhy, dbh, dby, hprev = lossFun(inputs, targets, hprev)
  File "minimal_rnn.py", line 43, in lossFun
    loss += -np.log(ps[t][targets[t],0]) # softmax (cross-entropy loss)
IndexError: list index out of range

Why am I getting this error?

I was getting the same error but it seems like the reason is that your input file is too small. Try writing more stuff in your input.txt

@lanttu1243
I made a two-layer recurrent neural network based on this and I am not sure why it does not work. Could anyone check it out and open a PR if you find a problem?
https://github.com/lanttu1243/vanilla_recurrent_neural_network.git

@anantinfinity9796
Hi, can anyone explain the line where we pass the gradient back through the softmax function, "dy[targets[t]] -= 1"? Why are we doing this operation?
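
That line follows from the derivative of the cross-entropy loss with respect to the unnormalized scores: for softmax probabilities p and target index k, the gradient is p with 1 subtracted at position k. A small numerical check with made-up scores (an illustrative sketch, not the gist's code):

```python
import numpy as np

rng = np.random.RandomState(1)
y = rng.randn(5, 1)   # made-up unnormalized scores for a 5-character vocabulary
target = 2            # made-up index of the correct next character

def loss(scores):
    """Cross-entropy of the softmax of scores against the target index."""
    p = np.exp(scores) / np.sum(np.exp(scores))
    return float(-np.log(p[target, 0]))

# Analytic gradient, as in the gist: copy the probabilities, subtract 1 at the target.
p = np.exp(y) / np.sum(np.exp(y))
dy = np.copy(p)
dy[target] -= 1

# Central finite differences for comparison.
eps = 1e-5
numeric = np.zeros_like(y)
for i in range(len(y)):
    d = np.zeros_like(y)
    d[i] = eps
    numeric[i] = (loss(y + d) - loss(y - d)) / (2 * eps)

# Largest absolute difference between the two gradients.
print(np.abs(numeric - dy).max())
```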

@sashakttripathi
To anyone finding the code hard to understand, I provide detailed explanations here: https://mkffl.github.io/ Hope it helps

Good explanation of the second part of lossFun.

@pjoer
pjoer commented Jan 31, 2023

Could someone explain how to use this? I can't generate text. It is running though.

@ijkilchenko
@pjoer, do you have an input.txt file (really any text file) in the same directory as this script? You should be running it like python min-char-rnn.py

It seems the above was written with Python 2 in mind, but I just updated a few lines and got it working on Python 3.9.6: https://gist.github.com/ijkilchenko/84be862a5e18240c59b4505177c9c34c

Good luck!

@pursuemoon
pursuemoon commented Sep 24, 2023

I have a question about line 58 of the code: dhnext = np.dot(Whh.T, dhraw). Could anyone tell me what it means?

The expression for forward propagation is:

$$ \begin{align} z_t & = W_{hh} \cdot h_{t-1} + W_{xh} \cdot x_t + b_h \tag{1}\\ h_t & = Tanh(z_t) \tag{2}\\ y_t & = W_{hy} \cdot h_t + b_y \tag{3}\\ o_t & = Softmax(y_t) \tag{4}\\ \end{align} $$

Here are the gradient expressions for the weights that I derived:

$$ \begin{align} \frac{\partial L}{\partial W_{hy}} & = \frac{\partial L}{\partial y_t} \otimes {h_t}^T \tag{a}\\ \frac{\partial L}{\partial W_{hh}} & = {\color{red} {W_{hy}}^T \cdot \frac{\partial L}{\partial y_t} \odot Tanh'(z_t)} \otimes {h_{t-1}}^T \tag{b}\\ \frac{\partial L}{\partial W_{xh}} & = {\color{red} {W_{hy}}^T \cdot \frac{\partial L}{\partial y_t} \odot Tanh'(z_t)} \otimes x_t^T \tag{c} \end{align} $$

It is easy to see that the red part is what dhraw represents in the code. And we can get dWxh and dWhh from formula (b) and formula (c) without dhnext. So what does dhnext mean?

@logeek404
Stupid question: in which lines does the vanishing gradient problem manifest itself?

It uses Adagrad in the last few lines, so each gradient is always divided by the square root of the accumulated sum of its squared values.
mem += dparam*dparam
param += -learning_rate * dparam / np.sqrt(mem+1e-8)

However, I find that mem and param are all local variables in the loop. I don't know if the implementation of Adagrad is correct.

@logeek404
I have a question about line 58 of the code: dhnext = np.dot(Whh.T, dhraw). Could anyone tell me what it means?

The expression for forward propagation is:

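$$ \begin{align} z_t & = W_{hh} \cdot h_{t-1} + W_{xh} \cdot x_t + b_h \tag{1}\\ h_t & = Tanh(z_t) \tag{2}\\ y_t & = W_{hy} \cdot h_t + b_y \tag{3}\\ o_t & = Softmax(y_t) \tag{4}\\ \end{align} $$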

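Here are the gradient expressions for the weights that I derived:

$$ \begin{align} \frac{\partial L}{\partial W_{hy}} & = \frac{\partial L}{\partial y_t} \otimes {h_t}^T \tag{a}\\ \frac{\partial L}{\partial W_{hh}} & = {\color{red} {W_{hy}}^T \cdot \frac{\partial L}{\partial y_t} \odot Tanh'(z_t)} \otimes {h_{t-1}}^T \tag{b}\\ \frac{\partial L}{\partial W_{xh}} & = {\color{red} {W_{hy}}^T \cdot \frac{\partial L}{\partial y_t} \odot Tanh'(z_t)} \otimes x_t^T \tag{c} \end{align} $$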

It is easy to see that the red part is what dhraw represents in the code. And we can get dWxh and dWhh from formula (b) and formula (c) without dhnext. So what does dhnext mean?

Since dh is the gradient of the loss wrt the hidden state, there are two paths along which gradients flow back (backpropagation). From the equations and the RNN structure, the hidden state feeds forward into both an output and the next hidden state. dhnext is the gradient reaching the current state from the next hidden state. Note that dhnext is zero on the first iteration of the backward loop because, at the last (unrolled) step of the RNN, no gradient flows in from a next hidden state.
Hope this helps.

@pursuemoon
I have a question about line 58 of the code: dhnext = np.dot(Whh.T, dhraw). Could anyone tell me what it means?
The expression for forward propagation is:
$$ \begin{align} z_t & = W_{hh} \cdot h_{t-1} + W_{xh} \cdot x_t + b_h \tag{1}\\ h_t & = Tanh(z_t) \tag{2}\\ y_t & = W_{hy} \cdot h_t + b_y \tag{3}\\ o_t & = Softmax(y_t) \tag{4}\\ \end{align} $$
Here are the gradient expressions for the weights that I derived:
$$ \begin{align} \frac{\partial L}{\partial W_{hy}} & = \frac{\partial L}{\partial y_t} \otimes {h_t}^T \tag{a}\\ \frac{\partial L}{\partial W_{hh}} & = {\color{red} {W_{hy}}^T \cdot \frac{\partial L}{\partial y_t} \odot Tanh'(z_t)} \otimes {h_{t-1}}^T \tag{b}\\ \frac{\partial L}{\partial W_{xh}} & = {\color{red} {W_{hy}}^T \cdot \frac{\partial L}{\partial y_t} \odot Tanh'(z_t)} \otimes x_t^T \tag{c} \end{align} $$
It is easy to see that the red part is what dhraw represents in the code. And we can get dWxh and dWhh from formula (b) and formula (c) without dhnext. So what does dhnext mean?

Since dh is the gradient of the loss wrt the hidden state, there are two paths along which gradients flow back (backpropagation). From the equations and the RNN structure, the hidden state feeds forward into both an output and the next hidden state. dhnext is the gradient reaching the current state from the next hidden state. Note that dhnext is zero on the first iteration of the backward loop because, at the last (unrolled) step of the RNN, no gradient flows in from a next hidden state. Hope this helps.

Thank you so much for replying. I missed the partial derivative wrt the next hidden state.
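
Written out in the notation of equations (1) to (4), and assuming L is the sum of the per-step losses L_t, the missing term can be sketched as follows. The first summand corresponds to np.dot(Why.T, dy) in the code, the second to dhnext, and the second line to dhraw:

$$ \begin{align} \frac{\partial L}{\partial h_t} & = {W_{hy}}^T \cdot \frac{\partial L_t}{\partial y_t} + {W_{hh}}^T \cdot \frac{\partial L}{\partial z_{t+1}} \\ \frac{\partial L}{\partial z_t} & = \frac{\partial L}{\partial h_t} \odot Tanh'(z_t) \end{align} $$

At the last unrolled step there is no z_{t+1}, so the second summand is zero, which matches dhnext being initialised with np.zeros_like.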

@kefei-cs19
@logeek404

However, I find that mem and param are all local variables in the loop. I don't know if the implementation of Adagrad is correct.

mem and param point to the numpy ndarrays from the zip, and += updates their values in-place.

@Mr-Second
Mr-Second commented Dec 11, 2023 via email
