Skip to content

Instantly share code, notes, and snippets.

@mcminis1
Created January 17, 2023 21:46
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save mcminis1/5ed44fb56891e0388fe47e52f8f12d66 to your computer and use it in GitHub Desktop.
RNN using PyTorch
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.nn import RNNCell, Linear, Sequential, HuberLoss, Flatten, Module
from torch.utils.data import DataLoader
# ---- Configuration ---------------------------------------------------------
# Model dimensions:
#   T          -- number of consecutive samples in each model input window
#   hidden_dim -- size of the RNN cell's hidden state
#   output_dim -- number of following samples predicted from each window
T, hidden_dim, output_dim = 6, 32, 8

# Optimisation:
#   alpha    -- HuberLoss delta: errors smaller than this are penalised
#               quadratically, larger ones linearly (gradient is capped)
#   eps      -- SGD learning rate
#   n_epochs -- number of full-batch training steps
alpha, eps, n_epochs = 0.025, 1e-1, 20000

# Data set sizes:
#   test_set_size -- windows held out for the test-set plots
#   n_samples     -- windows used for training
test_set_size, n_samples = 4, 128
def plot_tests(step, out_dir="step_plots"):
    """Plot test-set predictions against their targets and save a PNG.

    Each test window's prediction (red) and target (green) is drawn as its own
    panel along an artificial x axis. Vertical lines separate the panels. The
    figure is written to ``<out_dir>/<step:06d>.png`` and then cleared. The
    zero-padded step number keeps the files in training order when globbed,
    e.g. to build an animation with ImageMagick:

        convert -delay 10 -loop 1 step_plots/*.png rnn_optimization_pytorch.gif

    (ImageMagick's ``-loop 0`` makes the GIF repeat forever.)

    Reads the module-level ``model``, ``y_test``, ``x_grid``, ``dx_grid`` and
    ``output_dim``.

    Args:
        step: training step number; used as the file name.
        out_dir: directory for the PNGs. It is created if missing.
    """
    # savefig() does not create missing directories; without this the very
    # first call fails with FileNotFoundError on a fresh checkout.
    os.makedirs(out_dir, exist_ok=True)

    y_hats = predict_test_set(model)
    # One panel spans output_dim grid points plus one grid step of padding.
    panel_width = dx_grid * (output_dim + 1)
    v_lines = []
    for plot_i, (y_hat, ys) in enumerate(zip(y_hats.cpu().numpy(), y_test.cpu().numpy())):
        offset = panel_width * plot_i
        x = x_grid[:output_dim] + offset
        # Separator sits in the padding gap just left of this panel.
        v_lines.append(offset - dx_grid)
        plt.plot(x, y_hat, "r")
        plt.plot(x, ys, "g")
    # The first separator would fall left of the first panel, so skip it.
    for x_pos in v_lines[1:]:
        plt.vlines(x_pos, -1, 1)
    ax = plt.gca()
    ax.set_xticks([])  # the synthetic x positions carry no meaning
    ax.set_ylim([-1.1, 1.1])
    plt.savefig(os.path.join(out_dir, f"{step:06d}.png"), format='png')
    plt.clf()
# Training data: overlapping windows cut from one sine wave (our target
# signal). Each window holds T input samples followed by output_dim target
# samples. The grid length is chosen so that n_samples + test_set_size
# windows are generated below.
x_grid = np.linspace(0, 4 * np.pi, num=n_samples + test_set_size + T + output_dim)
dx_grid = x_grid[1] - x_grid[0]
sin_wave = np.sin(x_grid)
n_data_points = sin_wave.shape[0]
# NOTE: n_samples is rebound here to the total window count (train + test).
n_samples = n_data_points - T - output_dim
raw_X = [sin_wave[start : start + T + output_dim] for start in range(n_samples)]
np.random.shuffle(raw_X)
# The first test_set_size shuffled windows are held out; the rest train.
X_test, raw_X = raw_X[:test_set_size], raw_X[test_set_size:]
def get_tensors(l, split=None):
    """Split each window into an input prefix and a target suffix.

    Args:
        l: sequence of equal-length 1-D windows (in this script, numpy slices
            of ``sin_wave``).
        split: number of leading samples per window used as model input.
            Defaults to the module-level ``T``. The remaining samples form
            the target.

    Returns:
        ``(X, y)`` as float32 tensors. For a non-empty ``l`` their shapes are
        ``(len(l), split)`` and ``(len(l), window_len - split)``.
    """
    if split is None:
        split = T
    inputs = [window[:split] for window in l]
    targets = [window[split:] for window in l]
    # Stack into a single ndarray first. Building a tensor straight from a
    # Python list of ndarrays goes element by element, which is slow and
    # makes PyTorch emit a UserWarning.
    X = torch.from_numpy(np.asarray(inputs, dtype=np.float32))
    y = torch.from_numpy(np.asarray(targets, dtype=np.float32))
    return X, y
# Turn every window into (input, target) tensors for both data sets.
X, y = get_tensors(raw_X)
X_test, y_test = get_tensors(X_test)

# Prefer CUDA, then Apple's Metal backend (MPS), otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device} device")
# Model: one tanh RNN cell followed by a linear read-out.
class NeuralNetwork(Module):
    """Predict the next ``output_dim`` samples from a window of ``T`` samples.

    ``Sequential`` calls the ``RNNCell`` exactly once and passes no hidden
    state, so PyTorch starts from zeros. The whole window is consumed as a
    single T-dimensional input vector rather than stepped through as a
    sequence. In effect the cell behaves like a bias-free tanh layer.
    """

    def __init__(self):
        super().__init__()
        # Collapses every dimension after the batch dimension.
        self.flatten = Flatten()
        cell = RNNCell(T, hidden_dim, bias=False, nonlinearity='tanh')
        readout = Linear(hidden_dim, output_dim, bias=False)
        self.rnn_stack = Sequential(cell, readout)

    def forward(self, x):
        """Return ``output_dim`` predictions per input window in ``x``."""
        return self.rnn_stack(self.flatten(x))
# Objective: Huber loss using the configured delta.
loss_fn = HuberLoss(delta=alpha)
# Instantiate the network on the selected device and optimise all of its
# weights with plain SGD.
model = NeuralNetwork().to(device)
optimizer = torch.optim.SGD(params=model.parameters(), lr=eps)
def predict_test_set(model):
    """Return ``model``'s predictions for the held-out inputs ``X_test``.

    Switches the model to eval mode and runs the forward pass without
    autograd. The model is left in eval mode, and the result stays on
    ``device``; callers move it to the CPU as needed.

    Args:
        model: the network to evaluate.

    Returns:
        Tensor of predictions, one row per test window.
    """
    model.eval()
    with torch.no_grad():
        # Inference needs only the inputs. The targets (y_test) are compared
        # by the caller, so don't copy them to the device here.
        return model(X_test.to(device))
def train(step, X, y):
    """Take one full-batch SGD step on ``(X, y)``.

    Every 100th step also prints the loss and saves a test-set plot via
    ``plot_tests``.

    Args:
        step: index of the current training step.
        X: model inputs for the whole training set.
        y: matching targets.
    """
    model.train()
    inputs = X.to(device)
    targets = y.to(device)

    # Forward pass and Huber loss against the target continuation.
    batch_loss = loss_fn(model(inputs), targets)

    # Clear stale gradients, backpropagate, then apply the update.
    optimizer.zero_grad()
    batch_loss.backward()
    optimizer.step()

    if step % 100 != 0:
        return
    print(f"loss: {batch_loss.item():>7f}")
    plot_tests(step)
# The training tensors never change, so move them to the device once rather
# than on every step. Inside train(), .to(device) then returns the tensor
# itself instead of copying it again.
X, y = X.to(device), y.to(device)
for step in range(n_epochs):
    train(step, X, y)
# Snapshot the final model. Guarded so that n_epochs == 0 does not hit an
# undefined loop variable.
if n_epochs > 0:
    plot_tests(n_epochs - 1)
print("Done!")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment