Skip to content

Instantly share code, notes, and snippets.

@csarofeen
Created February 7, 2018 06:30
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save csarofeen/3ebb607d2a7c2c1093a134e356b9e867 to your computer and use it in GitHub Desktop.
import torch
import torch.nn as nn
from torch.autograd import Variable
import itertools
# Problem size shared by every LSTM configuration in the sweep.
seq_size, batch_size = 16, 32
inp_size, hidden_size = 64, 128
# Dropout is disabled (0.0); presumably so the two backends' results are
# directly comparable.
dropout = 0.0

# Sweep axes: every combination of these values is tested via itertools.product.
num_layers_l = list(range(1, 5))
bias_l = [True, False]
# Only batch_first=False is swept; True is currently excluded.
batch_first_l = [False]
bidirectional_l = [True, False]
def _run_lstm(use_cudnn, num_layers, bias, batch_first, bidirectional):
    """Run one seeded forward/backward pass of ``nn.LSTM`` on the GPU.

    Both RNG seeds are reset on every call, so the cuDNN and native runs draw
    the same random input (and, presumably, the same initial LSTM weights).

    Args:
        use_cudnn: value assigned to ``torch.backends.cudnn.enabled`` for
            this run (the caller is responsible for restoring it).
        num_layers, bias, batch_first, bidirectional: forwarded to
            ``nn.LSTM``.

    Returns:
        List of plain tensors ``[output, input grad, h_n, c_n]``, where
        ``h_n``/``c_n`` are the state tuple returned by ``nn.LSTM``.
    """
    torch.manual_seed(42)
    torch.cuda.manual_seed(42)
    torch.backends.cudnn.enabled = use_cudnn
    # The input layout must follow batch_first for BOTH backends; otherwise the
    # two runs feed differently shaped data and cannot be compared.
    if batch_first:
        shape = (batch_size, seq_size, inp_size)
    else:
        shape = (seq_size, batch_size, inp_size)
    inp = Variable(torch.cuda.FloatTensor(*shape).uniform_(), requires_grad=True)
    lstm = nn.LSTM(inp_size, hidden_size, num_layers, bias, batch_first,
                   dropout, bidirectional).cuda()
    out, (hx, cx) = lstm(inp)
    out.sum().backward()
    return [ten.data for ten in (out, inp.grad, hx, cx)]


# Max absolute difference tolerated between cuDNN and native results.
_ERR_TOLERANCE = 2e-6
# Labels for the tensors returned by _run_lstm, in order.
_TENSOR_NAMES = ("output", "input grad", "hx", "cx")

# _run_lstm mutates the global cuDNN switch; restore it once the sweep ends.
_prev_cudnn_enabled = torch.backends.cudnn.enabled
try:
    for config in itertools.product(
            num_layers_l, bias_l, batch_first_l, bidirectional_l):
        num_layers, bias, batch_first, bidirectional = config
        print("Testing config ", config)
        cudnn_tens = _run_lstm(True, num_layers, bias, batch_first, bidirectional)
        pyt_tens = _run_lstm(False, num_layers, bias, batch_first, bidirectional)
        for name, pyt_ten, cudnn_ten in zip(_TENSOR_NAMES, pyt_tens, cudnn_tens):
            # float() works whether max() yields a Python float or a 0-dim tensor.
            err = float((pyt_ten - cudnn_ten).abs().max())
            if err > _ERR_TOLERANCE:
                print("Found an error of ", err, " in ", name)
finally:
    torch.backends.cudnn.enabled = _prev_cudnn_enabled
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment