Skip to content

Instantly share code, notes, and snippets.

@iacolippo
Created August 28, 2017 10:37
Show Gist options
  • Save iacolippo/a73c96b1cd53536878fe4fb0acfcbc49 to your computer and use it in GitHub Desktop.
Save iacolippo/a73c96b1cd53536878fe4fb0acfcbc49 to your computer and use it in GitHub Desktop.
import torch.nn as nn
class SequentialSG(nn.Sequential):
    """Sequential container that routes its own gradOutput to ErrorFeedback modules.

    Backpropagation walks the modules from last to first, as in a plain
    Sequential, with one exception: before a module whose class is named
    ``ErrorFeedback`` is processed, the propagated gradient is reset to the
    ``gradOutput`` this container received.  The gradient chain is therefore
    restarted at every ErrorFeedback module.  The check is applied to every
    module, including the first one.  ErrorFeedback is recognised by class
    name only, so this class does not need to import it.

    NOTE(review): the API used here -- ``self.modules`` as a plain list,
    per-module ``output``/``gradInput`` attributes, ``backward(input,
    gradOutput, scale)``, ``accGradParameters`` and
    ``accUpdateGradParameters`` -- is the Lua-Torch-style container API
    (``torch.legacy.nn``).  On ``torch.nn.Sequential`` ``self.modules`` is a
    method and none of these hooks exist, so confirm which ``nn`` this class
    is meant to extend.
    """

    def _reverse_steps(self, first_input):
        """Yield ``(module, module_input)`` pairs from the last module to the first.

        A module's input is the ``output`` attribute of the module before it
        (presumably populated by a preceding forward pass -- confirm); the
        first module's input is *first_input*.  Yields nothing when the
        container is empty.
        """
        modules = self.modules
        for i in range(len(modules) - 1, 0, -1):
            yield modules[i], modules[i - 1].output
        if modules:
            yield modules[0], first_input

    def _propagate(self, input, gradOutput, step):
        """Thread a gradient backwards through the modules via *step*.

        Args:
            input: input to the first module.
            gradOutput: gradient w.r.t. this container's output.
            step: callable ``step(module, module_input, grad)`` that does one
                module's work and returns the gradient to hand to the module
                before it.

        Returns:
            The gradient returned by the last step (the first module), or
            *gradOutput* unchanged when the container has no modules.
        """
        grad = gradOutput
        for module, module_input in self._reverse_steps(input):
            # ErrorFeedback modules receive the container's error directly
            # instead of the gradient chained back from later modules.
            if type(module).__name__ == 'ErrorFeedback':
                grad = gradOutput
            grad = step(module, module_input, grad)
        return grad

    def accGradParameters(self, input, gradOutput, scale=1):
        """Call ``accGradParameters`` on every module in reverse order.

        Args:
            input: input to the first module.
            gradOutput: gradient w.r.t. this container's output.
            scale: forwarded unchanged to each module.
        """
        def step(module, module_input, grad):
            module.accGradParameters(module_input, grad, scale)
            # The gradient for the preceding module is whatever the module
            # holds in gradInput (presumably from an earlier backward pass).
            return module.gradInput

        self._propagate(input, gradOutput, step)

    def backward(self, input, gradOutput, scale=1):
        """Run ``backward`` on every module in reverse order.

        Each module's result is stored in its ``gradInput`` and chained to
        the module before it (subject to the ErrorFeedback reset).

        Args:
            input: input to the first module.
            gradOutput: gradient w.r.t. this container's output.
            scale: forwarded unchanged to each module's ``backward``.

        Returns:
            The gradient returned by the first module (or *gradOutput* when
            the container is empty); also stored in ``self.gradInput``.
        """
        def step(module, module_input, grad):
            module.gradInput = module.backward(module_input, grad, scale)
            return module.gradInput

        self.gradInput = self._propagate(input, gradOutput, step)
        return self.gradInput

    def accUpdateGradParameters(self, input, gradOutput, lr):
        """Call ``accUpdateGradParameters`` on every module in reverse order.

        Args:
            input: input to the first module.
            gradOutput: gradient w.r.t. this container's output.
            lr: forwarded unchanged to each module.
        """
        def step(module, module_input, grad):
            module.accUpdateGradParameters(module_input, grad, lr)
            # Chain the gradInput the module already holds.
            return module.gradInput

        self._propagate(input, gradOutput, step)

    def __repr__(self):
        """Render the data flow and each child, indenting multi-line child reprs."""
        tab = ' '
        line = '\n'
        arrow = ' -> '
        count = len(self.modules)
        flow = arrow.join(['[input'] + ['(' + str(i) + ')' for i in range(count)] + ['output]'])
        children = ''.join(
            line + tab + '(' + str(i) + '): ' + str(module).replace(line, line + tab)
            for i, module in enumerate(self.modules)
        )
        return 'SequentialSG {' + line + tab + flow + children + line + '}'
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment