Skip to content

Instantly share code, notes, and snippets.

@ezyang
Created March 26, 2023 17:04
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ezyang/2024bfdb6a2161f65ad7820264057fe7 to your computer and use it in GitHub Desktop.
Save ezyang/2024bfdb6a2161f65ad7820264057fe7 to your computer and use it in GitHub Desktop.
import torch
import torchvision
import torch._dynamo
# Dynamo config flag: per its name, treat input dimensions as static by
# default. NOTE(review): this is combined with dynamic=True in torch.compile
# below; confirm that this mix is the configuration the repro is exercising.
torch._dynamo.config.assume_static_by_default = True

# Hard-coded CUDA device: this script requires a GPU.
device = "cuda:0"

# Mask R-CNN (ResNet-50 FPN) built with torchvision's default constructor
# arguments, moved to `device`, then wrapped by torch.compile with dynamic
# shapes requested.
model = torchvision.models.detection.maskrcnn_resnet50_fpn()
model = model.to(device)
model = torch.compile(model, dynamic=True)
# eval() is called on the compiled wrapper rather than the original module.
# NOTE(review): presumably the mode propagates to the wrapped model — confirm.
model.eval()

# This works: a grad-enabled forward pass on a random (1, 3, 800, 800) tensor,
# then backward through the "boxes" entry of the first result.
# range(1) runs the body once; raise the count to repeat the iteration.
for _ in range(1):
    x = torch.rand(1, 3, 800, 800, device=device)
    r = model(x)
    r[0]['boxes'].sum().backward()
    # Alternative: backward through whichever tensor comes first in r[0].
    # (.items() would yield (key, value) pairs, and [0] would select the
    # string key, so take the first value directly instead.)
    # next(iter(r[0].values())).sum().backward()

# This does not work (per the original report): an inference-only forward
# under torch.no_grad(). It reuses the last `x` bound by the loop above.
# with torch.no_grad():
#     _ = model(x)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment