@Lyken17
Created August 14, 2023 09:47
import torch
from torch import nn

# Toggle between a Linear and a Conv2d example.
# net = nn.Linear(500, 500)
# input = torch.randn(64, 500)
net = nn.Conv2d(3, 3, kernel_size=3, padding=1)
input = torch.randn(1, 3, 32, 32)

# mode 1: only calculate the input grad; for the Linear case this prints
# ('_saved_mat2', torch.Size([500, 500]))
mode = 1

if mode == 1:
    # no activation saved
    print("calculate dy/dx only")
    input.requires_grad_(True)
    for p in net.parameters():
        p.requires_grad_(False)
elif mode == 2:
    # activation is saved
    print("calculate dy/dw only")
    input.requires_grad_(False)
    for p in net.parameters():
        p.requires_grad_(True)

print("compiling")
net = torch.compile(net)

output = input
for i in range(1):
    output = net(output)

# Inspect which tensors autograd saved for the backward pass.
for x in dir(output.grad_fn):
    if x.startswith("_saved"):
        data = getattr(output.grad_fn, x)
        if isinstance(data, torch.Tensor):
            print((x, data.shape))
Output with mode = 1 (Conv2d):

```
calculate dy/dx only
compiling
('_saved_input', torch.Size([1, 3, 32, 32]))
('_saved_weight', torch.Size([3, 3, 3, 3]))
```
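
For comparison, the same information can also be observed while the forward pass runs, using saved-tensor hooks. The snippet below is a minimal sketch, assuming PyTorch >= 1.10 (where `torch.autograd.graph.saved_tensors_hooks` is available); `pack_hook` and `unpack_hook` are just local helper names, not part of the gist above.

```python
import torch
from torch import nn

# Minimal sketch (assumes torch >= 1.10): log every tensor that autograd
# packs for the backward pass during the forward call.
def pack_hook(t):
    print("saved for backward:", tuple(t.shape))
    return t

def unpack_hook(t):
    return t

net = nn.Conv2d(3, 3, kernel_size=3, padding=1)
x = torch.randn(1, 3, 32, 32, requires_grad=True)

with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
    y = net(x)  # shapes of the tensors saved for backward are printed here
```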