Created
February 7, 2018 06:30
-
-
Save csarofeen/3ebb607d2a7c2c1093a134e356b9e867 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch
import torch.nn as nn
from torch.autograd import Variable
import itertools

# Compares nn.LSTM results computed with cuDNN enabled against the same
# computation with cuDNN disabled.  For every option combination below, an
# identically seeded input and LSTM are run forward and backward on the GPU
# under both settings, and each pair of result tensors is checked for a
# maximum absolute difference above a fixed tolerance.

# Problem sizes shared by every configuration.
seq_size = 16
batch_size = 32
inp_size = 64
hidden_size = 128

# Option values; the Cartesian product of the *_l lists is tested.
num_layers_l = [1, 2, 3, 4]
bias_l = [True, False]
batch_first_l = [False]  # add True to also exercise batch-first inputs
dropout = 0.0
bidirectional_l = [True, False]


def input_shape(batch_first):
    """Return the LSTM input shape for the requested layout.

    Args:
        batch_first: When true the batch dimension comes first; otherwise
            the sequence dimension comes first.

    Returns:
        ``(batch_size, seq_size, inp_size)`` if *batch_first* is true,
        else ``(seq_size, batch_size, inp_size)``.
    """
    if batch_first:
        return (batch_size, seq_size, inp_size)
    return (seq_size, batch_size, inp_size)


def run_lstm(config, use_cudnn):
    """Run one seeded forward/backward LSTM pass on the GPU.

    Args:
        config: ``(num_layers, bias, batch_first, bidirectional)`` tuple.
        use_cudnn: Value assigned to ``torch.backends.cudnn.enabled``
            before the input and the LSTM are created.

    Returns:
        A list ``[output, input gradient, final hidden state, final cell
        state]`` holding the ``.data`` tensor of each result.
    """
    num_layers, bias, batch_first, bidirectional = config

    # Reset both RNGs so every call starts from the same random state and
    # performs its draws (input first, then LSTM construction) in the same
    # order.
    torch.manual_seed(42)
    torch.cuda.manual_seed(42)
    torch.backends.cudnn.enabled = use_cudnn

    # Both the cuDNN and non-cuDNN runs take their shape from the same
    # helper, so the input layout always agrees with the batch_first flag
    # passed to the LSTM.
    inp = torch.cuda.FloatTensor(*input_shape(batch_first)).uniform_()
    inp = Variable(inp, requires_grad=True)

    lstm = nn.LSTM(
        inp_size,
        hidden_size,
        num_layers=num_layers,
        bias=bias,
        batch_first=batch_first,
        dropout=dropout,
        bidirectional=bidirectional,
    ).cuda()

    out, (hx, cx) = lstm(inp)
    # Backpropagate a scalar so that inp.grad is populated for comparison.
    out.sum().backward()

    return [ten.data for ten in (out, inp.grad, hx, cx)]


def main():
    """Run every configuration and print any difference above 2e-6."""
    for config in itertools.product(
            num_layers_l, bias_l, batch_first_l, bidirectional_l):
        print("Testing config ", config)

        cudnn_tens = run_lstm(config, use_cudnn=True)
        pyt_tens = run_lstm(config, use_cudnn=False)

        for pyt_ten, cudnn_ten in zip(pyt_tens, cudnn_tens):
            err = (pyt_ten - cudnn_ten).abs().max()
            if err > 2e-6:
                print("Found an error of ", err)


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment