Skip to content

Instantly share code, notes, and snippets.

@iacolippo
Created August 28, 2017 10:43
Show Gist options
  • Save iacolippo/e3980897940da1ff4a85f870366a5c8a to your computer and use it in GitHub Desktop.
Save iacolippo/e3980897940da1ff4a85f870366a5c8a to your computer and use it in GitHub Desktop.
import torch
from torch.autograd import Function, Variable
class ErrorFeedbackFunction(Function):
    """Identity op whose backward pass projects the incoming gradient.

    The forward pass copies ``input`` to the output. The backward pass does
    not use the identity gradient; instead it sends ``grad_output @ feedback``
    to ``input`` (a "random projection of the error"). No gradient is ever
    produced for ``feedback``.

    Use via ``ErrorFeedbackFunction.apply(input, feedback)``.
    """

    @staticmethod
    def forward(ctx, input, feedback):
        """Return ``input`` unchanged and stash ``feedback`` for backward.

        Args:
            ctx: autograd context object.
            input: tensor passed through untouched.
            feedback: matrix used to project the gradient in ``backward``.
                NOTE(review): presumably a fixed random matrix - confirm
                against the caller.

        Returns:
            ``input`` unchanged.
        """
        # Only ``feedback`` is needed in backward; saving ``input`` too would
        # keep that tensor alive until backward runs for no benefit.
        ctx.save_for_backward(feedback)
        return input

    @staticmethod
    def backward(ctx, grad_output):
        """Replace the gradient w.r.t. ``input`` with ``grad_output @ feedback``.

        Args:
            ctx: autograd context populated by ``forward``.
            grad_output: gradient w.r.t. this op's output. Because forward
                returns ``input``, it has the same shape as ``input``.

        Returns:
            ``(grad_input, grad_feedback)``: ``grad_input`` is
            ``grad_output @ feedback`` when ``input`` requires grad, otherwise
            ``None``; ``grad_feedback`` is always ``None``.

        Note:
            Autograd requires ``grad_input`` to have ``input``'s shape, so for
            an input whose last dimension is ``d`` the feedback matrix must be
            ``(d, d)``.
        """
        # ``saved_tensors`` is the supported accessor; ``saved_variables`` is
        # deprecated.
        (feedback,) = ctx.saved_tensors
        grad_input = grad_feedback = None
        if ctx.needs_input_grad[0]:
            # Random projection of the error. ``matmul`` is identical to
            # ``mm`` for 2-D gradients and also accepts leading batch dims.
            grad_input = grad_output.matmul(feedback)
        return grad_input, grad_feedback
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment