@yaroslavvb2
Last active October 22, 2017 19:36
PyTorch custom matmul backprop
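The backward function below has three modes. Write $B$ for `grad_output` and $A$ for `matrix2`, so the layer computes `matrix1 @ matrix2`. The gist does not show how `As_inv` and `Bs_inv` are built. The third line below assumes the usual K-FAC choice: damped second-moment matrices over the $n$ columns (batch), with a damping $\lambda$ that is a placeholder rather than a value from the gist. The last equality in the `kfac` line holds because the inverse of a symmetric matrix is symmetric, so $(\Sigma_A^{-1})^{\top} = \Sigma_A^{-1}$.

```latex
\begin{aligned}
\text{standard:}\quad & \nabla_{\mathrm{matrix1}} = B A^{\top} \\
\text{kfac:}\quad & \tilde{\nabla}_{\mathrm{matrix1}} = (\Sigma_B^{-1} B)(\Sigma_A^{-1} A)^{\top}
  = \Sigma_B^{-1}\,(B A^{\top})\,\Sigma_A^{-1} \\
\text{assumed factors:}\quad & \Sigma_A = \tfrac{1}{n} A A^{\top} + \lambda I, \qquad
  \Sigma_B = \tfrac{1}{n} B B^{\top} + \lambda I
\end{aligned}
```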
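import torch

mode = 'standard'  # global set by the training loop (assumed): 'capture', 'kfac' or 'standard'
As, Bs, As_inv, Bs_inv = [], [], [], []  # per-layer statistics and inverted factors (assumed globals)

class Matmul(torch.autograd.Function):  # hypothetical name; forward assumed to be matrix1 @ matrix2
    @staticmethod
    def forward(ctx, matrix1, matrix2):
        ctx.save_for_backward(matrix1, matrix2)
        return torch.mm(matrix1, matrix2)

    @staticmethod
    def backward(ctx, grad_output):
        matrix1, matrix2 = ctx.saved_tensors  # ctx.saved_variables in PyTorch <= 0.3
        grad_matrix1 = grad_matrix2 = None
        if mode == 'capture':  # record statistics only; no weight gradient is produced
            # Backward visits layers last-to-first, so prepend to keep forward order.
            Bs.insert(0, grad_output.detach())
            As.insert(0, matrix2.detach())
        elif mode == 'kfac':  # precondition with inverted factors; pop() yields the last layer first
            kfac_A = As_inv.pop() @ matrix2
            kfac_B = Bs_inv.pop() @ grad_output
            grad_matrix1 = torch.mm(kfac_B, kfac_A.t())
        elif mode == 'standard':
            grad_matrix1 = torch.mm(grad_output, matrix2.t())
        if ctx.needs_input_grad[1]:  # assumed fix: propagate to matrix2 so earlier layers get gradients
            grad_matrix2 = torch.mm(matrix1.t(), grad_output)
        return grad_matrix1, grad_matrix2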
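The 'kfac' branch pops from `As_inv` and `Bs_inv`, but the gist does not show how those lists are filled from the `As` and `Bs` recorded in 'capture' mode. The following is a minimal sketch under the standard K-FAC assumption that each inverse factor is the inverse of a damped second-moment matrix. `invert_factors` and its `damping` value are hypothetical and not part of the original. The final check confirms that the 'kfac' product equals the standard gradient `B @ A.t()` multiplied on the left by the inverted `B` factor and on the right by the inverted `A` factor.

```python
import torch


def invert_factors(As, Bs, damping=1e-3):
    """Hypothetical helper: build K-FAC-style inverse factors from captured statistics.

    Assumes each A has shape (in_features, batch) and each B has shape
    (out_features, batch), like matrix2 and grad_output in the backward pass.
    The damping value is an arbitrary placeholder, not taken from the gist.
    """
    As_inv, Bs_inv = [], []
    for A, B in zip(As, Bs):
        n = A.shape[1]  # batch size (number of columns)
        cov_A = A @ A.t() / n
        cov_B = B @ B.t() / n
        eye_A = torch.eye(cov_A.shape[0], dtype=A.dtype, device=A.device)
        eye_B = torch.eye(cov_B.shape[0], dtype=B.dtype, device=B.device)
        # Damping keeps both factors invertible even when batch < features.
        As_inv.append(torch.linalg.inv(cov_A + damping * eye_A))
        Bs_inv.append(torch.linalg.inv(cov_B + damping * eye_B))
    # Same (forward) order as As/Bs, so list.pop() yields the last layer first.
    return As_inv, Bs_inv


# Check: the 'kfac' product equals inv_B @ (B @ A.t()) @ inv_A.
torch.manual_seed(0)
A = torch.randn(4, 8, dtype=torch.float64)  # stands in for matrix2
B = torch.randn(3, 8, dtype=torch.float64)  # stands in for grad_output
As_inv, Bs_inv = invert_factors([A], [B])
kfac_grad = (Bs_inv[0] @ B) @ (As_inv[0] @ A).t()
assert torch.allclose(kfac_grad, Bs_inv[0] @ (B @ A.t()) @ As_inv[0])
```

Because the 'kfac' branch consumes the lists with `pop()`, they would have to be rebuilt, for example by calling the helper again, before each 'kfac' backward pass.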