@Eeman1113
Created May 22, 2024 20:38
def backward(self, gradient=None):
    if not self.requires_grad:
        return
    if gradient is None:
        if self.shape == [1]:
            gradient = Tensor([1])  # dx/dx = 1: seed gradient for a scalar output
        else:
            raise RuntimeError("Gradient argument must be specified for non-scalar tensors.")
    if self.grad is None:
        self.grad = gradient
    else:
        self.grad += gradient  # accumulate gradients arriving along multiple paths
    if self.grad_fn is not None:  # not a leaf: an operation produced this tensor
        grads = self.grad_fn.backward(gradient)  # ask that operation for its input gradients
        for tensor, grad in zip(self.grad_fn.input, grads):
            if isinstance(tensor, Tensor):
                tensor.backward(grad)  # recurse: the chain rule walks back through the graph
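
To see the recursion play out end to end, here is a minimal runnable sketch of the surrounding machinery this method assumes. The list-backed data and shape, the Add operation class, and its `input` list are illustrative assumptions rather than the gist's actual definitions; the backward method above is repeated verbatim so the sketch is self-contained.

class Tensor:
    def __init__(self, data, requires_grad=False):
        self.data = data                    # assumed: a plain Python list
        self.shape = [len(data)]
        self.requires_grad = requires_grad
        self.grad = None                    # accumulated gradient (a Tensor)
        self.grad_fn = None                 # the operation that produced this tensor

    def __add__(self, other):
        out = Tensor([a + b for a, b in zip(self.data, other.data)],
                     requires_grad=self.requires_grad or other.requires_grad)
        out.grad_fn = Add(self, other)      # record the graph edge for backprop
        return out

    def backward(self, gradient=None):
        if not self.requires_grad:
            return
        if gradient is None:
            if self.shape == [1]:
                gradient = Tensor([1])      # dx/dx = 1 seed for a scalar output
            else:
                raise RuntimeError("Gradient argument must be specified for non-scalar tensors.")
        if self.grad is None:
            self.grad = gradient
        else:
            self.grad += gradient           # accumulate over multiple paths
        if self.grad_fn is not None:        # not a leaf
            grads = self.grad_fn.backward(gradient)
            for tensor, grad in zip(self.grad_fn.input, grads):
                if isinstance(tensor, Tensor):
                    tensor.backward(grad)   # chain rule, recursively


class Add:
    # d(a + b)/da = 1 and d(a + b)/db = 1, so the upstream gradient
    # flows through unchanged to both inputs.
    def __init__(self, a, b):
        self.input = [a, b]

    def backward(self, gradient):
        return [gradient, gradient]


x = Tensor([2.0], requires_grad=True)
y = Tensor([3.0], requires_grad=True)
z = x + y        # z.grad_fn is Add(x, y)
z.backward()     # scalar output, so the seed defaults to Tensor([1])
print(x.grad.data, y.grad.data)  # -> [1] [1]

With this wiring, calling backward() on the output seeds the gradient with 1, hands it to the producing operation, and the operation splits it among its inputs, each of which recurses until the leaves accumulate their gradients.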