@johannah
Last active May 1, 2018 17:12
learning rnn sinewaves
# from KK
# learning to predict a sine wave with a minimal hand-rolled RNN
# written against the old PyTorch Variable API (note Variable, init.normal, .data[0])
import torch
from torch.autograd import Variable
import numpy as np
import matplotlib.pyplot as plt
import torch.nn.init as init
from IPython import embed
dtype = torch.FloatTensor
input_size, hidden_size, output_size = 1,128,1
epochs = 20000  # number of training iterations used in the loop below
seq_length = 20
lr = 1e-2
# make sine wave data
data_time_steps = np.linspace(2,10, seq_length+1)
data = np.sin(data_time_steps)
data.resize((seq_length+1), 1)
batch_size = 10
# replicate the same wave across the batch -> shape (seq_length+1, batch_size, 1)
batch_data = np.array([data for d in range(batch_size)]).transpose(1, 0, 2)
# target is input data shifted by one time step
x = Variable(torch.FloatTensor(batch_data[:-1]), requires_grad=False)
y = Variable(torch.FloatTensor(batch_data[1:]), requires_grad=False)
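# sanity check (added, not in the original gist): x and y should both be
# (seq_length, batch_size, input_size); torch.Size compares equal to a plain tuple
assert x.size() == (seq_length, batch_size, input_size)
assert y.size() == (seq_length, batch_size, input_size)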
# weights are initialized using a normal distribution with zero mean
# w_inp is for input and hidden connections
# will transform input to hidden
w_inp = Variable(torch.FloatTensor(input_size, hidden_size), requires_grad=True)
init.normal(w_inp,0.0,0.05)
# w_out is for hidden to output connections
w_out = Variable(torch.FloatTensor(hidden_size, output_size), requires_grad=True)
init.normal(w_out,0.0,0.05)
# initial hidden state: zeros, one row per batch element
h_init = Variable(torch.FloatTensor(np.zeros((batch_size, hidden_size))), requires_grad=False)
def recurrent_fn(x_t, h_tm1):
    # project x onto hidden
    proj_x = x_t.mm(w_inp)
    h_t = torch.tanh(proj_x + h_tm1)
    return h_t
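# Note (added, not in the original gist): a standard Elman RNN would also apply a
# learned hidden-to-hidden weight matrix before the nonlinearity; a minimal sketch,
# assuming a hypothetical w_hid initialized like w_inp but with shape
# (hidden_size, hidden_size):
#   h_t = torch.tanh(x_t.mm(w_inp) + h_tm1.mm(w_hid))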
def one_pass(x, y):
    # run the recurrence over the whole sequence, collecting hidden states
    outputs = []
    h_tm1 = h_init
    for i in range(len(x)):
        h_t = recurrent_fn(x[i], h_tm1)
        h_tm1 = h_t
        outputs.append(h_t[None])
    outputs = torch.cat(outputs)
    # project every hidden state to the output and score against the shifted target
    y_pred = outputs.matmul(w_out)
    mse_loss = ((y_pred - y) ** 2).mean()
    return mse_loss, y_pred
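# Note (added, not in the original gist): the squared-error mean above could
# equivalently be written with torch.nn.functional.mse_loss, e.g.:
#   import torch.nn.functional as F
#   mse_loss = F.mse_loss(y_pred, y)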
for e in range(epochs):
    mse_loss, y_pred = one_pass(x, y)
    if not e % 100:
        print('starting epoch {} loss {}'.format(e, mse_loss.data[0]))
    # backprop through the unrolled sequence, then take a manual SGD step
    mse_loss.backward()
    w_inp.data -= lr * w_inp.grad.data
    w_out.data -= lr * w_out.grad.data
    # clear gradients so they do not accumulate across iterations
    w_inp.grad.data.zero_()
    w_out.grad.data.zero_()
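# Note (added, not in the original gist): the manual update above is plain SGD;
# the same thing could be done with the built-in optimizer, e.g.:
#   optimizer = torch.optim.SGD([w_inp, w_out], lr=lr)
#   # then per iteration: optimizer.step() after backward(), followed by optimizer.zero_grad()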
# plot the prediction for the first sequence in the batch against its target
plt.plot(y_pred.data.numpy()[:, 0], label='ypred')
plt.plot(y.data.numpy()[:,0], label='y')
plt.legend()
plt.show()
embed()
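# For reference (added sketch, not part of the original gist): roughly the same
# forward pass written with the built-in torch.nn.RNNCell, which also learns a
# hidden-to-hidden matrix and bias terms that the hand-rolled version above omits:
#   cell = torch.nn.RNNCell(input_size, hidden_size, nonlinearity='tanh')
#   readout = torch.nn.Linear(hidden_size, output_size)
#   h_t = Variable(torch.zeros(batch_size, hidden_size))
#   preds = []
#   for i in range(len(x)):
#       h_t = cell(x[i], h_t)
#       preds.append(readout(h_t)[None])
#   y_pred_cell = torch.cat(preds)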