Skip to content

Instantly share code, notes, and snippets.

/-

Created September 30, 2017 20:09
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save anonymous/49c10bc17ac4a97307d52c07d01a2870 to your computer and use it in GitHub Desktop.
Save anonymous/49c10bc17ac4a97307d52c07d01a2870 to your computer and use it in GitHub Desktop.
diff --git a/torch/autograd/_functions/blas.py b/torch/autograd/_functions/blas.py
index 787ac3b..f54fc1c 100644
--- a/torch/autograd/_functions/blas.py
+++ b/torch/autograd/_functions/blas.py
@@ -13,6 +13,27 @@ def _get_output(ctx, arg, inplace=False):
return arg.new().resize_as_(arg)
+class Mm(InplaceFunction):
+    """Autograd function for the matrix product ``torch.mm(matrix1, matrix2)``."""
+
+    @staticmethod
+    def forward(ctx, matrix1, matrix2):
+        """Save both operands for backward and return their product."""
+        ctx.save_for_backward(matrix1, matrix2)
+        return torch.mm(matrix1, matrix2)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        """Return (grad_matrix1, grad_matrix2); an entry is None if not needed."""
+        lhs, rhs = ctx.saved_variables
+        wants = ctx.needs_input_grad
+
+        # For C = A.mm(B): dL/dA = dL/dC.mm(B^T) and dL/dB = A^T.mm(dL/dC).
+        grad_lhs = torch.mm(grad_output, rhs.t()) if wants[0] else None
+        grad_rhs = torch.mm(lhs.t(), grad_output) if wants[1] else None
+        return grad_lhs, grad_rhs
+
+
class Addmm(InplaceFunction):
@staticmethod
diff --git a/torch/autograd/variable.py b/torch/autograd/variable.py
index cd3b0b7..a803e04 100644
--- a/torch/autograd/variable.py
+++ b/torch/autograd/variable.py
@@ -583,8 +583,12 @@ class Variable(_C._VariableBase):
return self._static_blas(cls, (self,) + args, inplace)
def mm(self, matrix):
+        """Return the matrix product of ``self`` and ``matrix``.
+
+        A sparse ``self`` is dispatched to ``Mm`` and requires a dense
+        ``matrix``; otherwise the product is written into a freshly
+        allocated output through ``Addmm``.
+
+        Raises:
+            RuntimeError: if both ``self`` and ``matrix`` are sparse.
+        """
-        output = Variable(self.data.new(self.data.size(0), matrix.data.size(1)))
-        return Addmm.apply(output, self, matrix, 0, 1, True)
+        if self.data.is_sparse:
+            # Raise instead of assert: asserts are stripped under ``python -O``,
+            # which would let a sparse right-hand operand slip through.
+            if matrix.data.is_sparse:
+                raise RuntimeError("mm: sparse x sparse is not supported; "
+                                   "matrix must be dense when self is sparse")
+            return Mm.apply(self, matrix)
+        else:
+            output = Variable(self.data.new(self.data.size(0), matrix.data.size(1)))
+            return Addmm.apply(output, self, matrix, 0, 1, True)
def bmm(self, batch):
output = Variable(self.data.new(self.data.size(0), self.data.size(1),
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment