-
-
Save cosmic-cortex/3a9b68554a4aec935771ddf781b4b2f7 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class Function:
    """
    Abstract model of a differentiable function.

    Subclasses are expected to override ``forward``, ``backward`` and
    ``local_grad``. Calling an instance runs ``forward`` and then stores
    the result of ``local_grad`` (called with the same arguments) in
    ``self.grad``.
    """

    def __init__(self, *args, **kwargs):
        # Cache for intermediate results; helps with gradient calculation
        # in some cases.
        self.cache = {}
        # Cache for gradients; replaced on every call of the instance.
        self.grad = {}

    def __call__(self, *args, **kwargs):
        """
        Evaluate the function and cache its local gradients.

        All positional and keyword arguments are forwarded unchanged to
        both ``forward`` and ``local_grad``.

        Returns:
            Whatever ``forward`` returns.
        """
        # Calculate the output first, so any state ``forward`` stores on
        # the instance is already in place when ``local_grad`` runs.
        output = self.forward(*args, **kwargs)
        # Calculate and cache the local gradients for the same input.
        self.grad = self.local_grad(*args, **kwargs)
        return output

    def forward(self, *args, **kwargs):
        """
        Forward pass of the function. Calculates the output value.

        Gradients are not computed here: ``__call__`` obtains them from a
        separate ``local_grad`` call. The base implementation does nothing
        and returns None.
        """
        pass

    def backward(self, *args, **kwargs):
        """
        Backward pass. Computes the global gradient at the input value
        after forward pass.

        The base implementation does nothing and returns None.
        """
        pass

    def local_grad(self, *args, **kwargs):
        """
        Calculates the local gradients of the function at the given input.

        The base implementation does nothing and returns None.

        Returns:
            grad: dictionary of local gradients.
        """
        pass
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment