RNN from scratch
import numpy as np
import matplotlib.pyplot as plt
import pickle
# Configuration
## RNN definition
### maximum length of sequence
T = 6
### hidden state vector dimension
hidden_dim = 32
### output length
output_dim = 8
## training params
### cutoff for linear gradient (note: defined here but unused below; l_grad clips at a fixed ±1)
alpha = 0.025
### learning rate
eps = 1e-1
### number of training epochs
n_epochs = 10000
### number of samples to reserve for test
test_set_size = 4
### number of samples to generate
n_samples = 50
rng = np.random.default_rng(2882)
# the "hidden layer". aka the transition matrix. these are the weights in the RNN.
# shape: hidden_dim x hidden_dim
W = rng.normal(0, (hidden_dim * hidden_dim) ** -0.75, size=(hidden_dim, hidden_dim))
_, W = np.linalg.qr(W, mode='complete')
# input matrix. translates from input vector to W
# shape: hidden_dim x T
U = rng.normal(0, (hidden_dim * T) ** -0.75, size=(hidden_dim, T))
svd_u, _, svd_vh = np.linalg.svd(U, full_matrices=False)
U = np.dot(svd_u, svd_vh)
# output matrix. translates from W to the output vector
# shape: output_dim x hidden_dim
V = rng.normal(0, (output_dim * hidden_dim) ** -0.75, size=(output_dim, hidden_dim))
svd_u, _, svd_vh = np.linalg.svd(V, full_matrices=False)
V = np.dot(svd_u, svd_vh)
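# illustrative sanity checks on the orthogonal initializations (safe to delete):
# W is square-orthogonal, U has orthonormal columns, V has orthonormal rows.
assert np.allclose(np.dot(W.T, W), np.eye(hidden_dim), atol=1e-8)
assert np.allclose(np.dot(U.T, U), np.eye(T), atol=1e-8)
assert np.allclose(np.dot(V, V.T), np.eye(output_dim), atol=1e-8)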
# hidden state update: s_t = sigmoid(U @ x_t + W @ s_{t-1})
def new_hidden_state(x, s):
    u = np.dot(U, x)
    w = np.dot(W, s)
    rv = 1 / (1 + np.exp(-(u + w)))
    return rv
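# illustrative check: with zero input and zero state, every hidden unit sits
# at sigmoid(0) = 0.5.
assert np.allclose(new_hidden_state(np.zeros(T), np.zeros(hidden_dim)), 0.5)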
def el_mul(v, m):
    # scale row i of m by v[i]
    r = np.zeros_like(m)
    for c in range(r.shape[1]):
        r[:, c] = v * m[:, c]
    return r
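# el_mul is the row-wise scaling v[:, None] * m; a small illustrative check of
# the equivalence:
assert np.allclose(el_mul(np.arange(3.0), np.ones((3, 4))),
                   np.arange(3.0)[:, None] * np.ones((3, 4)))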
def l_grad(dy):
    # clip each component of the error signal to [-1, 1] so a single large
    # residual cannot blow up an update
    return np.clip(dy, -1.0, 1.0)
def plot_tests(step):
    # run the RNN over each held-out window and plot the prediction (red)
    # against the true continuation (green)
    v_lines = []
    for plot_i, x_y in enumerate(X_test):
        xs = x_y[:T]
        ys = x_y[T:]
        rnn_s = np.zeros(hidden_dim, dtype=np.float64)
        for t in range(T):
            x_i = np.zeros(T, dtype=np.float64)
            x_i[t] = xs[t]
            rnn_s = new_hidden_state(x_i, rnn_s)
        y_hat = np.dot(V, rnn_s)
        x = x_grid[:output_dim] + dx_grid * (output_dim + 1) * plot_i
        v_lines.append(dx_grid * (output_dim + 1) * plot_i - dx_grid)
        plt.plot(x, y_hat, "r")
        plt.plot(x, ys, "g")
    for x_pos in v_lines[1:]:
        plt.vlines(x_pos, -1, 1)
    frame1 = plt.gca()
    frame1.axes.get_xaxis().set_ticks([])
    frame1.set_ylim([-1.1, 1.1])
    plt.savefig(f"step_plots/{step:06d}.png", format='png')
    plt.clf()
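# plot_tests saves into step_plots/, which must exist before the first call or
# plt.savefig raises; creating it up front is an illustrative convenience.
import os
os.makedirs("step_plots", exist_ok=True)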
# set up training data:
# let's use sin as our target function. each sample is a window of T inputs
# followed by output_dim target values.
x_grid = np.linspace(0, 4 * np.pi, num=n_samples + test_set_size + T + output_dim)
dx_grid = x_grid[1] - x_grid[0]
sin_wave = np.sin(x_grid)
n_data_points = sin_wave.shape[0]
n_samples = n_data_points - T - output_dim
X = []
for i in range(0, n_samples):
    X.append(sin_wave[i : i + T + output_dim])
# shuffle with the seeded generator so the train/test split is reproducible
rng.shuffle(X)
X_test = X[:test_set_size]
X = X[test_set_size:]
# redefine n_samples as the actual training-set size so the averaged learning
# rate and per-point error below use the right count
n_samples = len(X)
print(f"n_data_points: {n_data_points}")
print(f"n_samples: {len(X)}")
print(f"n_test: {len(X_test)}")
print(f"input length : {T}")
print(f"hidden_dim length: {hidden_dim}")
print(f"output length: {output_dim}")
# gradients below are summed over the training set, so average them by folding
# the dataset size into the learning rate
eps = eps / n_samples
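# BPTT recap: with s_t = sigmoid(U x_t + W s_{t-1}) and y_hat = V s_T, the
# inner loop below accumulates the Jacobians recursively through time:
#   ds_t/dU = diag(s_t (1 - s_t)) W · ds_{t-1}/dU + outer(s_t (1 - s_t), x_t)
#   ds_t/dW = diag(s_t (1 - s_t)) W · ds_{t-1}/dW + outer(s_t (1 - s_t), s_{t-1})
# and the output-layer error (clipped by l_grad) is chained back through V.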
for e_i in range(n_epochs):
    loss = 0
    dL_dV = 0
    dL_dU = 0
    dL_dW = 0
    for x_y in X:
        xs = x_y[:T]
        ys = x_y[T:]
        rnn_s = np.zeros(hidden_dim, dtype=np.float64)
        rnn_ds_dU = np.zeros((hidden_dim, T), dtype=np.float64)
        rnn_ds_dW = np.zeros((hidden_dim, hidden_dim), dtype=np.float64)
        for t in range(T):
            x_i = np.zeros(T, dtype=np.float64)
            x_i[t] = xs[t]
            p_rnn_s = rnn_s
            rnn_s = new_hidden_state(x_i, rnn_s)
            # derivs: ds is the sigmoid derivative s * (1 - s); the Jacobians
            # ds/dU and ds/dW are accumulated recursively through time
            ds = rnn_s * (1 - rnn_s)
            ds_W = el_mul(ds, W)
            rnn_ds_dU = np.dot(ds_W, rnn_ds_dU)
            rnn_ds_dU += np.outer(ds, x_i)
            rnn_ds_dW = np.dot(ds_W, rnn_ds_dW)
            rnn_ds_dW += np.outer(ds, p_rnn_s)
        # output error and squared loss for this window
        dy = np.dot(V, rnn_s) - ys
        rnn_dL_dV = np.outer(l_grad(dy), rnn_s)
        dyV = np.dot(l_grad(dy), V)
        loss_i = (0.5 * dy**2).sum()
        rnn_dL_dW = el_mul(dyV, rnn_ds_dW)
        rnn_dL_dU = el_mul(dyV, rnn_ds_dU)
        loss += loss_i
        dL_dV += rnn_dL_dV
        dL_dW += rnn_dL_dW
        dL_dU += rnn_dL_dU
    if (e_i + 1) % 100 == 0 or e_i == 0:
        print(
            f"{e_i}: total loss: {loss}\n\t\t<error> per data point: {np.sqrt(loss/n_samples/output_dim)}"
        )
        print(f"    dV range: {np.max(dL_dV) - np.min(dL_dV)}")
        print(f"    dU range: {np.max(dL_dU) - np.min(dL_dU)}")
        print(f"    dW range: {np.max(dL_dW) - np.min(dL_dW)}")
        plot_tests(e_i)
    # gradient-descent step once per epoch on the accumulated gradients
    W = W - eps * dL_dW
    V = V - eps * dL_dV
    U = U - eps * dL_dU
with open('weights.pkl', 'wb') as f:
    pickle.dump([U, W, V], f, protocol=pickle.HIGHEST_PROTOCOL)
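# to reload the trained weights later (illustrative usage):
# with open('weights.pkl', 'rb') as f:
#     U, W, V = pickle.load(f)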