Skip to content

Instantly share code, notes, and snippets.

@aviator19941
Last active December 14, 2023 00:27
Show Gist options
  • Save aviator19941/3e251322ca373dea89e316969332c9c4 to your computer and use it in GitHub Desktop.
Save aviator19941/3e251322ca373dea89e316969332c9c4 to your computer and use it in GitHub Desktop.
SymIntArrayRef reproducer
import shark_turbine.aot as aot
import torch
import torch._dynamo as dynamo
from torch._export import dynamic_dim
from torch._export.constraints import constrain_as_size, constrain_as_value
from typing import Optional
class SampleModel(torch.nn.Module):
    """Parameter-free module that returns Gaussian noise shaped like its input."""

    def __init__(self):
        super().__init__()

    def forward(self, inp, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
        """Draw a standard-normal tensor with the same shape as ``inp``.

        Only ``inp.shape`` is read; the values of ``inp`` are never used.

        Args:
            inp: Tensor whose shape determines the shape of the result.
            generator: Optional RNG forwarded to ``torch.randn``.

        Returns:
            A CPU ``float32`` tensor of shape ``inp.shape``.
        """
        noise = torch.randn(inp.shape, generator=generator, device="cpu")
        # Explicitly pin the device and dtype of the result before returning.
        return noise.to(device="cpu", dtype=torch.float32)
# Build the module under test.
sample_model = SampleModel()

# Example input used for export: a random float32 tensor of shape (1, 4, 64, 64).
example_x = torch.rand((1, 4, 64, 64), dtype=torch.float32)

# Export the module with the example input, then print the exported result.
exported = aot.export(sample_model, example_x)
exported.print_readable()

# NOTE(review): save_to=None presumably keeps the compiled artifact in memory
# instead of writing it to disk -- confirm against the shark_turbine docs.
compiled_binary = exported.compile(save_to=None)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment