Last active
May 1, 2018 17:12
-
-
Save johannah/407492a7adfa4d52c8e5ca6da470a9b4 to your computer and use it in GitHub Desktop.
Learning RNN sine waves
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# from KK | |
import torch | |
from torch.autograd import Variable | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import torch.nn.init as init | |
from IPython import embed | |
# Hyperparameters and sine-wave training data setup.
dtype = torch.FloatTensor
input_size, hidden_size, output_size = 1, 128, 1
epochs = 300  # NOTE(review): unused — the training loop below runs 20000 steps
seq_length = 20
lr = 1e-2

# Sample seq_length+1 points of a sine wave on [2, 10], shaped (seq_length+1, 1).
data_time_steps = np.linspace(2, 10, seq_length + 1)
data = np.sin(data_time_steps)
data.resize((seq_length + 1), 1)

# Replicate the same sequence across the batch; layout is (time, batch, feature).
batch_size = 10
batch_data = np.array([data for _ in range(batch_size)]).transpose(1, 0, 2)

# Target is the input shifted forward by one time step (next-value prediction).
x = Variable(torch.FloatTensor(batch_data[:-1]), requires_grad=False)
y = Variable(torch.FloatTensor(batch_data[1:]), requires_grad=False)

# Weights are drawn from N(0, 0.05): w_inp maps input -> hidden,
# w_out maps hidden -> output.  init.normal_ is the in-place initializer;
# the non-underscore init.normal was deprecated in PyTorch 0.4 and later removed.
w_inp = Variable(torch.FloatTensor(input_size, hidden_size), requires_grad=True)
init.normal_(w_inp, 0.0, 0.05)
w_out = Variable(torch.FloatTensor(hidden_size, output_size), requires_grad=True)
init.normal_(w_out, 0.0, 0.05)

# Initial hidden state: zeros, one row per batch element.
h_init = Variable(torch.FloatTensor(np.zeros((batch_size, hidden_size))), requires_grad=False)
def recurrent_fn(x_t, h_tm1):
    """One vanilla-RNN step: tanh(x_t @ w_inp + h_tm1).

    x_t is the (batch, input_size) slice for one time step; h_tm1 is the
    previous (batch, hidden_size) hidden state.  Note there is no
    hidden-to-hidden weight matrix — the previous state is added directly
    to the input projection.  Reads the module-global w_inp.
    """
    return torch.tanh(torch.mm(x_t, w_inp) + h_tm1)
def one_pass(x, y):
    """Unroll the RNN over the full sequence and score it against y.

    x and y are (time, batch, feature) tensors; uses the module-globals
    h_init (initial hidden state), recurrent_fn and w_out.  Returns a
    tuple (mse_loss, y_pred) with y_pred shaped like y.
    """
    hidden_states = []
    h_prev = h_init
    for x_t in x:
        h_prev = recurrent_fn(x_t, h_prev)
        hidden_states.append(h_prev)
    # (time, batch, hidden) stack of hidden states, projected to the output.
    y_pred = torch.stack(hidden_states).matmul(w_out)
    mse_loss = ((y_pred - y) ** 2).mean()
    return mse_loss, y_pred
# Training loop: manual SGD on the two weight matrices.
for e in range(20000):
    mse_loss, y_pred = one_pass(x, y)
    if not e % 100:
        # .item() extracts the scalar loss; the original .data[0] indexes a
        # 0-dim tensor and raises IndexError on PyTorch >= 0.5.
        print('starting epoch {} loss {}'.format(e, mse_loss.item()))
    mse_loss.backward()
    # SGD step on the raw .data, then clear the accumulated gradients so
    # they do not carry over into the next backward pass.
    w_inp.data -= lr * w_inp.grad.data
    w_out.data -= lr * w_out.grad.data
    w_inp.grad.data.zero_()
    w_out.grad.data.zero_()

# Plot the first batch element's final prediction against the target,
# then drop into an IPython shell for inspection.
plt.plot(y_pred.data.numpy()[:, 0], label='ypred')
plt.plot(y.data.numpy()[:, 0], label='y')
plt.legend()
plt.show()
embed()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment