Skip to content

Instantly share code, notes, and snippets.

@iacolippo
Created August 28, 2017 11:52
Show Gist options
  • Save iacolippo/91486cd95d77e705f61894ae162b267a to your computer and use it in GitHub Desktop.
Save iacolippo/91486cd95d77e705f61894ae162b267a to your computer and use it in GitHub Desktop.
from FunctionErrorFeedback import ErrorFeedbackFunction
from torch.autograd import Function, Variable
import torch
import torch.nn as nn
class EF(nn.Module):
    """Module that passes its input to ``ErrorFeedbackFunction`` along with a
    fixed, randomly initialised feedback matrix.

    NOTE(review): the names suggest a feedback-alignment style layer (error
    signals sent back through a fixed random matrix); the actual forward and
    backward behaviour lives in ``FunctionErrorFeedback`` and should be
    confirmed there.

    Args:
        layer_dim: number of columns of the feedback matrix (presumably the
            width of the layer this module follows -- confirm with callers).
        error_dim: number of rows of the feedback matrix (presumably the
            size of the error signal -- confirm with callers).
    """

    def __init__(self, layer_dim, error_dim):
        super(EF, self).__init__()
        # Register the matrix as a buffer, not a plain attribute. Then
        # .cuda()/.to()/.double() convert it together with the rest of the
        # module, and it is saved in and restored from state_dict().
        # It stays a non-Parameter (as before), so it is not returned by
        # parameters() and an optimizer will not update it.
        self.register_buffer("feedback", torch.Tensor(error_dim, layer_dim))
        self.reset_variables()

    def reset_variables(self):
        """Refill the feedback matrix in place with samples from U(-0.1, 0.1)."""
        # brutal initialization: fixed range, independent of the matrix size
        self.feedback.uniform_(-0.1, 0.1)

    def forward(self, input):
        """Return ``ErrorFeedbackFunction(input, self.feedback)``."""
        return ErrorFeedbackFunction(input, self.feedback)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment