class Function:
    """
    Abstract model of a differentiable function.
    """

    def __init__(self, *args, **kwargs):
        # cache for intermediate results of the forward pass,
        # which can be reused when computing gradients
        self.cache = {}
        # cache for local gradients
        self.grad = {}

    def __call__(self, *args, **kwargs):
        # calculating output
        output = self.forward(*args, **kwargs)
        # calculating and caching local gradients
        self.grad = self.local_grad(*args, **kwargs)
        return output

    def forward(self, *args, **kwargs):
        """
        Forward pass of the function. Calculates the output value; the
        local gradients are computed separately by local_grad when the
        function is called.
        """
        raise NotImplementedError

    def backward(self, *args, **kwargs):
        """
        Backward pass. Computes the global gradient at the input value
        after the forward pass.
        """
        raise NotImplementedError

    def local_grad(self, *args, **kwargs):
        """
        Calculates the local gradients of the function at the given input.

        Returns:
            grad: dictionary of local gradients.
        """
        raise NotImplementedError
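The base class only fixes the interface. The following is a minimal sketch of how subclasses might fill it in, using two hypothetical functions, `Sigmoid` and `Multiply`. Their names and the restriction to scalar inputs are assumptions for illustration and do not come from the original gist. Each `backward` applies the chain rule, multiplying the upstream gradient by the local gradients that `__call__` stored in `self.grad`. The base class is repeated in condensed form so the snippet runs on its own.

```python
import math


class Function:
    """Condensed copy of the abstract base class, so this sketch runs standalone."""

    def __init__(self, *args, **kwargs):
        self.cache = {}
        self.grad = {}

    def __call__(self, *args, **kwargs):
        output = self.forward(*args, **kwargs)
        self.grad = self.local_grad(*args, **kwargs)
        return output

    def forward(self, *args, **kwargs):
        raise NotImplementedError

    def backward(self, *args, **kwargs):
        raise NotImplementedError

    def local_grad(self, *args, **kwargs):
        raise NotImplementedError


class Sigmoid(Function):
    """Hypothetical example subclass: sigmoid of a single scalar."""

    def forward(self, x):
        # store the output; the derivative s * (1 - s) reuses it
        self.cache["s"] = 1.0 / (1.0 + math.exp(-x))
        return self.cache["s"]

    def local_grad(self, x):
        # runs after forward inside __call__, so the cached output exists
        s = self.cache["s"]
        return {"x": s * (1.0 - s)}

    def backward(self, dy):
        # chain rule: upstream gradient times local gradient
        return dy * self.grad["x"]


class Multiply(Function):
    """Hypothetical two-input example subclass: f(x, y) = x * y."""

    def forward(self, x, y):
        return x * y

    def local_grad(self, x, y):
        # df/dx = y, df/dy = x
        return {"x": y, "y": x}

    def backward(self, dy):
        # one gradient per input, keyed like the local gradients
        return {"x": dy * self.grad["x"], "y": dy * self.grad["y"]}


# usage: z = sigmoid(x * y), then gradients of z with respect to x and y
mul, sig = Multiply(), Sigmoid()
z = sig(mul(2.0, 0.5))              # sigmoid(1.0) ≈ 0.731
grads = mul.backward(sig.backward(1.0))
```

Caching the sigmoid output in `forward` lets `local_grad` reuse it instead of recomputing the exponential. This is the kind of intermediate result the `cache` dictionary is intended to hold.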