@zou3519
Created June 28, 2021 21:14
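import torch
from itertools import product  # needed by foo(); missing from the original snippet
from functorch import grad

device = 'cpu'
x = torch.randn([], requires_grad=True)
a = torch.tensor([1], device=device)
b = torch.tensor([1, 2, 3], device=device)
c = torch.tensor([1, 2], device=device)

def foo(x):
    # Construct a new tensor inside the function being differentiated,
    # from the Cartesian product of [a], b, and c (6 triples).
    y = torch.tensor(list(product([a], b, c)), device=device)
    return (x + y).sum()

grad(foo)(x)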