@apaszke
Created May 11, 2017 15:50
From baaec5d9f9e0e32bbd7d089d99698e1d83966f5d Mon Sep 17 00:00:00 2001
From: Adam Paszke <adam.paszke@gmail.com>
Date: Thu, 11 May 2017 07:36:09 -0700
Subject: [PATCH 1/2] Disable fused RNN kernels
---
test/test_nn.py | 24 ++++++++++++++----------
torch/nn/_functions/rnn.py | 7 ++++---
torch/nn/_functions/thnn/rnnFusedPointwise.py | 2 ++
3 files changed, 20 insertions(+), 13 deletions(-)
diff --git a/test/test_nn.py b/test/test_nn.py
index cfe2304..d2a649e 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -1673,16 +1673,20 @@ class TestNN(NNTestCase):
         self.assertEqual(hidden1, hidden2)
 
     def _test_rnn_retain_variables(self, dtype):
-        rnn = nn.LSTM(10, 20, num_layers=2).type(dtype)
-        input = Variable(torch.randn(5, 6, 10).type(dtype), requires_grad=True)
-        output = rnn(input)
-        output[0].sum().backward(retain_graph=True)
-        grads = [input.grad.data.clone()] + [p.grad.data.clone() for p in rnn.parameters()]
-        rnn.zero_grad()
-        input.grad.data.zero_()
-        output[0].sum().backward(retain_graph=True)
-        grads2 = [input.grad.data] + [p.grad.data for p in rnn.parameters()]
-        self.assertEqual(grads, grads2)
+        rnns = [nn.LSTM(10, 20, num_layers=2).type(dtype),
+                nn.GRU(10, 20, num_layers=2).type(dtype),
+                nn.RNN(10, 20, num_layers=2).type(dtype)]
+        for rnn in rnns:
+            input = Variable(torch.randn(5, 6, 10).type(dtype), requires_grad=True)
+            output = rnn(input)
+            output[0].sum().backward(retain_graph=True)
+            grads = [input.grad.data.clone()] + [p.grad.data.clone() for p in rnn.parameters()]
+            for i in range(4):
+                rnn.zero_grad()
+                input.grad.data.zero_()
+                output[0].sum().backward(retain_graph=True)
+                grads2 = [input.grad.data] + [p.grad.data for p in rnn.parameters()]
+                self.assertEqual(grads, grads2)
 
     def test_rnn_retain_variables(self):
         self._test_rnn_retain_variables(torch.DoubleTensor)
diff --git a/torch/nn/_functions/rnn.py b/torch/nn/_functions/rnn.py
index 6881a2d..85a705e 100644
--- a/torch/nn/_functions/rnn.py
+++ b/torch/nn/_functions/rnn.py
@@ -20,7 +20,8 @@ def RNNTanhCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
 
 
 def LSTMCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
-    if input.is_cuda:
+    # TODO: enable fused again
+    if False and input.is_cuda:
         igates = F.linear(input, w_ih)
         hgates = F.linear(hidden[0], w_hh)
         state = fusedBackend.LSTMFused()
@@ -43,8 +44,8 @@ def LSTMCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
 
 
 def GRUCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
-
-    if input.is_cuda:
+    # TODO: enable fused again
+    if False and input.is_cuda:
         gi = F.linear(input, w_ih)
         gh = F.linear(hidden, w_hh)
         state = fusedBackend.GRUFused()
diff --git a/torch/nn/_functions/thnn/rnnFusedPointwise.py b/torch/nn/_functions/thnn/rnnFusedPointwise.py
index 19f05e5..7a789a7 100644
--- a/torch/nn/_functions/thnn/rnnFusedPointwise.py
+++ b/torch/nn/_functions/thnn/rnnFusedPointwise.py
@@ -8,6 +8,7 @@ class GRUFused(Function):
         self.backend = None
 
     def forward(self, input_gate, hidden_gate, hx, ibias=None, hbias=None):
+        raise RuntimeError("fused RNNs are disabled")
         if self.backend is None:
             self.backend = type2backend[type(input_gate)]
         hy = input_gate.new()
@@ -46,6 +47,7 @@ class LSTMFused(Function):
         self.backend = None
 
     def forward(self, input_gate, hidden_gate, cx, ibias=None, hbias=None):
+        raise RuntimeError("fused RNNs are disabled")
         if self.backend is None:
             self.backend = type2backend[type(input_gate)]
         hy = input_gate.new()
--
2.9.3
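
Note (not part of the patch): once the fused CUDA branch is short-circuited by "if False and input.is_cuda:", LSTMCell and GRUCell fall through to the plain Python/autograd implementation further down in torch/nn/_functions/rnn.py, and the widened test re-runs backward() with retain_graph=True for LSTM, GRU, and RNN modules so all three exercise that path. The sketch below is only an approximation of the unfused LSTM step (the name lstm_cell_unfused and the exact layout are illustrative, not copied from the file); it shows the two matrix multiplies plus the pointwise gate math that the fused kernel otherwise collapses into a single launch.

import torch
import torch.nn.functional as F


def lstm_cell_unfused(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
    # Illustrative sketch, not the code from torch/nn/_functions/rnn.py.
    # hidden is the (h, c) pair carried between time steps.
    hx, cx = hidden
    # One GEMM for the input projection, one for the hidden projection.
    gates = F.linear(input, w_ih, b_ih) + F.linear(hx, w_hh, b_hh)
    # Split the 4 * hidden_size activations into the four LSTM gates.
    ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)

    ingate = torch.sigmoid(ingate)
    forgetgate = torch.sigmoid(forgetgate)
    cellgate = torch.tanh(cellgate)
    outgate = torch.sigmoid(outgate)

    # Pointwise state updates -- the part LSTMFused runs as one CUDA kernel.
    cy = (forgetgate * cx) + (ingate * cellgate)
    hy = outgate * torch.tanh(cy)
    return hy, cy

Both paths compute the same gate equations; the fused kernel only fuses the pointwise section, so forcing the fallback trades speed for the ordinary autograd behavior that the expanded retain_graph test checks.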