Skip to content

Instantly share code, notes, and snippets.

@HaraldKorneliussen
Created January 4, 2018 12:21
Show Gist options
  • Save HaraldKorneliussen/aa1264019e47258b7a948e5cb02c566c to your computer and use it in GitHub Desktop.
import torch
import torch.nn as nn
from torch.autograd import Variable
import numpy as np
# Batch size of the training sequence (the whole series is one batch element).
BATCH_SIZE = 1
# Number of features per timestep fed to the RNN.
INPUT_DIM = 1
# Number of features per timestep produced by the linear readout.
OUTPUT_DIM = 1
# NumPy dtype used when loading the dataset; also selects model precision below.
DTYPE = np.float32
class Net(nn.Module):
    """Elman-RNN regressor: multi-layer nn.RNN followed by a per-timestep
    linear readout.

    Tensors follow nn.RNN's default layout of (seq_len, batch, features).

    Args:
        input_dim: features per input timestep.
        hidden_dim: width of each recurrent layer.
        output_dim: features per output timestep.
        hidden_layers: number of stacked RNN layers.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, hidden_layers):
        super(Net, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.hidden_layers = hidden_layers
        self.rnn = nn.RNN(input_dim, hidden_dim, hidden_layers)
        self.h2o = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Run the sequence through the RNN and project every timestep.

        Returns a tensor of shape (seq_len, batch, output_dim).
        """
        # Derive batch size, device, and dtype from the input instead of the
        # module-level BATCH_SIZE constant and a hard-coded .cuda(): the
        # original only worked for batch size 1 on a CUDA device.
        h_0 = torch.zeros(self.hidden_layers, x.size(1), self.hidden_dim,
                          device=x.device, dtype=x.dtype)
        output, h_t = self.rnn(x, h_0)
        output = self.h2o(output)
        return output
def weights_init(m):
    """Initializer for ``model.apply``: Xavier/orthogonal weights, zero biases.

    Handles the first layer of an nn.RNN and any nn.Linear; other module
    types are left untouched.
    """
    if isinstance(m, nn.RNN):
        # Input-to-hidden weights: Xavier; hidden-to-hidden: orthogonal — a
        # common recipe to keep the recurrent dynamics well-conditioned.
        # The underscore-suffixed in-place initializers replace the
        # deprecated (and since removed) xavier_uniform/orthogonal/constant,
        # and act on the parameters directly instead of through .data.
        nn.init.xavier_uniform_(m.weight_ih_l0)
        nn.init.orthogonal_(m.weight_hh_l0)
        nn.init.constant_(m.bias_ih_l0, 0)
        nn.init.constant_(m.bias_hh_l0, 0)
    elif isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        nn.init.constant_(m.bias, 0)
# ---------------------------------------------------------------------------
# Training script: fit the RNN to the series in data/mg17.csv
# (column 0 = input timestep, column 1 = target timestep).
# ---------------------------------------------------------------------------

# Pick the GPU when available instead of unconditionally calling .cuda(),
# which crashed on CPU-only machines.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

data = np.loadtxt('data/mg17.csv', delimiter=',', dtype=DTYPE)
# Shape (4000, 1, 1): (seq_len, batch, features), nn.RNN's default layout.
trX = torch.from_numpy(np.expand_dims(data[:4000, [0]], axis=1)).to(device)
trY = torch.from_numpy(np.expand_dims(data[:4000, [1]], axis=1)).to(device)

loss_fcn = nn.MSELoss()
model = Net(INPUT_DIM, 10, OUTPUT_DIM, 1).to(device)
# Match model precision to the dtype the data was loaded with.
if DTYPE == np.float32:
    model.float()
else:
    model.double()
model.apply(weights_init)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
for e in range(500):
    model.train()
    optimizer.zero_grad()
    # Call the model, not model.forward(), so nn.Module hooks still run.
    output = model(trX)
    loss = loss_fcn(output, trY)
    loss.backward()
    optimizer.step()
    # loss.item() replaces loss.cpu().data.numpy()[0], which raises
    # IndexError on modern PyTorch because the loss is a 0-dim tensor.
    print("Epoch", e + 1, "TR:%e" % loss.item())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment