# GitHub Gist by @soulitzer, created November 16, 2023 06:49
# (gist: soulitzer/57b613a5de007fbce4af3ba85e5dcfea).
# Repro script: torch.compile (aot_eager backend) on views of two jagged
# nested tensors, with a guard on the tensors' leading stride.
import torch
from torch.nested._internal.nested_tensor import jagged_from_list
# Four dense float32 leaf tensors that require grad. They differ only in their
# leading length (2, 3, 4, 5), which becomes the ragged dimension once they are
# packed into a jagged nested tensor below. The generator yields them in order,
# so the random draws happen in the same sequence as four separate calls.
a, b, c, d = (
    torch.randn(rows, 7, 256, requires_grad=True, dtype=torch.float32)
    for rows in (2, 3, 4, 5)
)
# Build two nested tensors from the same components, each through its own
# jagged_from_list call with offsets=None. Only element [0] of the returned
# sequence (the nested tensor itself) is kept.
components = [a, b, c, d]
nt1 = jagged_from_list(components, None)[0]
nt2 = jagged_from_list(components, None)[0]

# Take index 1 along dim 2 of each nested tensor. Presumably dim 2 is the
# size-7 dimension of the components (batch, ragged, 7, 256) — TODO confirm.
nt1_view, nt2_view = (nt.select(2, 1) for nt in (nt1, nt2))
def fn(x, y):
    """Return a copy of ``x`` after comparing the leading strides of ``x`` and ``y``.

    The comparison's outcome does not affect the result. It exists so that,
    when this function is traced by ``torch.compile``, the stride condition
    is evaluated (per the original "guard on the stride" note).

    Args:
        x: tensor to copy; ``x.stride()[0]`` is compared with ``y``'s.
        y: tensor whose ``stride()[0]`` is compared with ``x``'s; otherwise
            unused.

    Returns:
        ``x.clone()``.
    """
    # guard on the stride: keep this as an actual branch. Branching forces the
    # comparison to be turned into a Python bool, whereas merely computing
    # the comparison (e.g. ``_ = a == b``) would not. The empty body is
    # intentional.
    if x.stride()[0] == y.stride()[0]:
        pass
    return x.clone()
# Compile `fn` with the "aot_eager" backend, then run it once on the two
# nested-tensor views built above; the result is stored in `out`.
fn_compiled = torch.compile(
    model=fn,
    backend="aot_eager",
)
out = fn_compiled(
    nt1_view,
    nt2_view,
)
# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment