import torch

# torch.compile applied as a decorator to a plain function.
@torch.compile
def opt_foo2(x, y):
    a = torch.sin(x)
    b = torch.cos(y)
    return a + b

print(opt_foo2(torch.randn(10, 10), torch.randn(10, 10)))
class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(100, 10)

    def forward(self, x):
        return torch.nn.functional.relu(self.lin(x))

mod = MyModule()
opt_mod = torch.compile(mod)
print(opt_mod(torch.randn(10, 100)))
# Times fn() using CUDA events and returns (result, elapsed seconds).
def timed(fn):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    result = fn()
    end.record()
    torch.cuda.synchronize()
    return result, start.elapsed_time(end) / 1000
# Generates random input and target data for the model, where `b` is
# the batch size.
def generate_data(b):
    return (
        torch.randn(b, 3, 128, 128).to(torch.float32).cuda(),
        torch.randint(1000, (b,)).cuda(),
    )
N_ITERS = 10
from torchvision.models import resnet18

def init_model():
    return resnet18().to(torch.float32).cuda()

def evaluate(mod, inp):
    return mod(inp)
model = init_model()
# Reset since we are using a different mode.
import torch._dynamo
torch._dynamo.reset()
evaluate_opt = torch.compile(evaluate, mode="reduce-overhead")
inp = generate_data(16)[0]
print("eager:", timed(lambda: evaluate(model, inp))[1])
print("compile:", timed(lambda: evaluate_opt(model, inp))[1])