-
-
Save amankharwal/6fec6503948842515ef588e1fc36e44e to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Train `model` with mean-squared-error loss and plain SGD, logging the loss
# once per epoch.
# Relies on names defined earlier in the script (not visible in this chunk):
# `model`, `learningRate`, `epochs`, and the NumPy arrays `x_train` / `y_train`.
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learningRate)

# Convert the training data to tensors once, before the loop: the data is the
# same every epoch, so re-converting it per iteration was wasted work.
# `torch.autograd.Variable` is deprecated (a no-op wrapper since PyTorch 0.4),
# so plain tensors are used directly.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
inputs = torch.from_numpy(x_train).to(device)
labels = torch.from_numpy(y_train).to(device)
# NOTE(review): this block does not move `model` to `device`; the forward pass
# requires model and data on the same device — confirm the model is placed on
# CUDA elsewhere when CUDA is available.

for epoch in range(epochs):
    # Clear gradients accumulated by the previous step.
    optimizer.zero_grad()
    # Forward pass: predictions for the full training set.
    outputs = model(inputs)
    # Mean-squared error between predictions and targets.
    loss = criterion(outputs, labels)
    # Backpropagate to compute gradients w.r.t. the model parameters.
    loss.backward()
    # Apply one SGD update to the parameters.
    optimizer.step()
    print('epoch {}, loss {}'.format(epoch, loss.item()))
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.