Skip to content

Instantly share code, notes, and snippets.

Last active April 1, 2023 18:51
Show Gist options
  • Save metric-space/f95eb8f2ead9c93f0c76ff52be490c40 to your computer and use it in GitHub Desktop.
Save metric-space/f95eb8f2ead9c93f0c76ff52be490c40 to your computer and use it in GitHub Desktop.
Karpathy's char-RNN implemented with primitive JAX operations
import jax.numpy as jnp
from jax import jit, vmap, grad, value_and_grad
from jax import random
import jax
# Fixed seed for the root JAX PRNG key.
SEED = 42234
# Root PRNG key; the rest of the script splits it before drawing random numbers.
key = random.PRNGKey(SEED)
# hyperparameters
hidden_size = 100  # size of the hidden state vector (rows of Wxh/Whh, columns of Why)
seq_length = 25  # number of characters in each training slice of the corpus
learning_rate = 1e-3  # factor applied to the gradient in update()
def initialize_network_params(hidden_size, vocab_size, key):
    """Create the initial weights and biases of the character RNN.

    Args:
        hidden_size: Number of units in the hidden state.
        vocab_size: Number of distinct symbols (input and output width).
        key: JAX PRNG key from which the weight matrices are drawn.

    Returns:
        Tuple ``(Wxh, Whh, Why, bh, by)``:
        ``Wxh`` (hidden_size, vocab_size) input-to-hidden weights,
        ``Whh`` (hidden_size, hidden_size) hidden-to-hidden weights,
        ``Why`` (vocab_size, hidden_size) hidden-to-output weights,
        ``bh`` (hidden_size, 1) hidden bias and ``by`` (vocab_size, 1)
        output bias. Weights are standard normal draws scaled by 0.01;
        both biases are zero.
    """
    # The split yields four keys. The first is discarded and each of the
    # remaining three seeds exactly one matrix, so no key is used twice.
    _, k_xh, k_hh, k_hy = random.split(key, num=4)

    Wxh = random.normal(k_xh, (hidden_size, vocab_size)) * 0.01  # input to hidden
    Whh = random.normal(k_hh, (hidden_size, hidden_size)) * 0.01  # hidden to hidden
    Why = random.normal(k_hy, (vocab_size, hidden_size)) * 0.01  # hidden to output
    bh = jnp.zeros((hidden_size, 1))  # hidden bias
    by = jnp.zeros((vocab_size, 1))  # output bias
    return Wxh, Whh, Why, bh, by
def loss(params, inputs, targets, hprev):
    """Run the RNN over one sequence and sum the cross-entropy of its predictions.

    Args:
        params: ``(Wxh, Whh, Why, bh, by)`` as returned by
            ``initialize_network_params``.
        inputs: Sequence of integer symbol indices fed to the network.
        targets: Sequence of integer symbol indices to predict; ``targets[t]``
            is the expected output after consuming ``inputs[t]``.
        hprev: Hidden state column vector to start from, shape
            ``(hidden_size, 1)``.

    Returns:
        ``(total, h)`` where ``total`` is the summed negative log-likelihood of
        the targets and ``h`` is the hidden state after the last input. The
        pair layout matches ``value_and_grad(..., has_aux=True)``.
    """
    Wxh, Whh, Why, bh, by = params
    total = 0
    h = hprev  # JAX arrays are immutable, so no defensive copy is needed
    for t, ix in enumerate(inputs):
        # One-hot column vector for the current input symbol; `by` has one row
        # per vocabulary entry, so its shape doubles as the one-hot shape.
        x = jnp.zeros(by.shape).at[ix].set(1)
        h = jnp.tanh(Wxh @ x + Whh @ h + bh)
        y_pred = Why @ h + by
        # log_softmax (rather than log(softmax)) keeps the loss numerically stable.
        log_prob = jax.nn.log_softmax(y_pred.flatten())
        total += -log_prob[targets[t]]
    return total, h
def sample(params, hprev, seed_ix, n, key):
    """Draw a sequence of symbol indices from the RNN.

    Args:
        params: ``(Wxh, Whh, Why, bh, by)`` network parameters.
        hprev: Hidden state column vector to start from, shape
            ``(hidden_size, 1)``.
        seed_ix: Index of the symbol fed to the network at the first step.
        n: Number of symbols to sample.
        key: JAX PRNG key; it is split once per sampled symbol.

    Returns:
        ``(ixes, key_)`` where ``ixes`` is a list of ``n`` Python ints (the
        sampled indices, in order) and ``key_`` is the PRNG key left over
        after the last split.
    """
    Wxh, Whh, Why, bh, by = params
    # The output bias has one row per symbol, so the vocabulary size comes from
    # the parameters themselves instead of a module-level global.
    vocab_size = by.shape[0]

    x = jnp.zeros((vocab_size, 1)).at[seed_ix].set(1)
    h = hprev
    ixes = []
    key_ = key
    for _ in range(n):
        key_, subkey = random.split(key_)
        h = jnp.tanh(Wxh @ x + Whh @ h + bh)
        y_pred = Why @ h + by
        p = jax.nn.softmax(y_pred.flatten())
        ix = int(jax.random.choice(subkey, jnp.arange(vocab_size), p=p, replace=False))
        ixes.append(ix)  # record the draw; without this the result is always empty
        # Feed the sampled symbol back in as the next one-hot input.
        x = jnp.zeros((vocab_size, 1)).at[ix].set(1)
    return ixes, key_
def update(params, inputs, targets, hprev):
    """Perform one gradient-descent step on a single training sequence.

    Args:
        params: Current ``(Wxh, Whh, Why, bh, by)`` network parameters.
        inputs: Sequence of integer input symbol indices.
        targets: Sequence of integer target symbol indices.
        hprev: Hidden state to start the sequence from.

    Returns:
        ``(new_params, loss_, hprev)``: the list of updated parameters, the
        loss of the sequence under the *old* parameters, and the hidden state
        after the last input (to be carried into the next sequence).
    """
    (loss_, hprev), grads = value_and_grad(loss, has_aux=True)(params, inputs, targets, hprev)
    # Clip the gradients, not the weights, to [-5, 5] to contain exploding
    # gradients, as the char-RNN reference implementation does.
    new_params = [
        w - learning_rate * jnp.clip(dw, -5, 5)
        for (w, dw) in zip(params, grads)
    ]
    return new_params, loss_, hprev
# Load the training corpus and build the character <-> index lookup tables.
with open('input.txt', 'r') as f:
    data = f.read()
# Sorting makes the index assigned to each character the same on every run;
# a bare set has no stable iteration order across interpreter sessions.
chars = sorted(set(data))
char_to_ix = {ch: i for i, ch in enumerate(chars)}
ix_to_char = {i: ch for i, ch in enumerate(chars)}
vocab_size = len(chars)

params = initialize_network_params(hidden_size, vocab_size, key)
p = -seq_length  # advanced by seq_length at the top of every iteration
n = 0
hprev = jnp.zeros((hidden_size, 1))
# Snapshot of the initial parameters; the loop below only works with `params`.
Wxh, Whh, Why, bh, by = params

# Train forever: stop the script manually once the samples look good enough.
while True:
    p = p + seq_length
    # Start over from the beginning of the corpus (with a fresh hidden state)
    # on the first iteration and whenever a full slice plus its shifted
    # targets no longer fits.
    if ((p + seq_length + 1) >= len(data)) or n == 0:
        p = 0
        hprev = jnp.zeros((hidden_size, 1))

    # Targets are the inputs shifted by one character.
    inputs = [char_to_ix[ch] for ch in data[p:p + seq_length]]
    targets = [char_to_ix[ch] for ch in data[p + 1:p + seq_length + 1]]

    # Periodically print a sample from the model to show progress.
    if n % 1000 == 0:
        key, subkey = random.split(key)
        sample_ix, key = sample(params, hprev, inputs[0], 200, subkey)
        txt = ''.join(ix_to_char[ix] for ix in sample_ix)
        print('----\n %s \n----' % (txt, ))

    params, loss_, hprev = update(params, inputs, targets, hprev)
    if n % 100 == 0:
        print('iter %d, loss: %f' % (n, loss_))
    n += 1
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment