Skip to content

Instantly share code, notes, and snippets.

@cosmic-cortex
Last active October 23, 2019 07:43
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save cosmic-cortex/3a9b68554a4aec935771ddf781b4b2f7 to your computer and use it in GitHub Desktop.
Save cosmic-cortex/3a9b68554a4aec935771ddf781b4b2f7 to your computer and use it in GitHub Desktop.
class Function:
    """
    Abstract model of a differentiable function.

    Subclasses are expected to override ``forward``, ``backward`` and
    ``local_grad``. Calling an instance runs ``forward`` on the given
    arguments and then stores the result of ``local_grad`` (evaluated on
    the same arguments) in ``self.grad``.
    """

    def __init__(self, *args, **kwargs):
        # cache for intermediate results;
        # helps with gradient calculation in some cases
        self.cache = {}
        # cache for gradients
        self.grad = {}

    def __call__(self, *args, **kwargs):
        """
        Evaluates the function and caches its local gradients.

        Returns:
            output: whatever ``forward`` returns for the given arguments.
        """
        # calculating output
        output = self.forward(*args, **kwargs)
        # calculating and caching local gradients
        local_grads = self.local_grad(*args, **kwargs)
        # An un-overridden ``local_grad`` returns None; keep ``self.grad`` a
        # dictionary in that case so it never silently changes type.
        self.grad = {} if local_grads is None else local_grads
        return output

    def forward(self, *args, **kwargs):
        """
        Forward pass of the function. Calculates the output value at the
        given input. (Local gradients are obtained separately, through
        ``local_grad``, when the instance is called.)
        """
        pass

    def backward(self, *args, **kwargs):
        """
        Backward pass. Computes the global gradient at the input value
        after forward pass.
        """
        pass

    def local_grad(self, *args, **kwargs):
        """
        Calculates the local gradients of the function at the given input.

        Returns:
            grad: dictionary of local gradients.
        """
        pass
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment