@soulitzer
Created July 17, 2023 15:24
NestedTensor python torch dispatch wrapper subclass
import torch
from torch.library import Library
from typing import List
from torch.utils._pytree import tree_map
import functools
# NestedTensor __torch_dispatch__ wrapper tensor subclass POC
#
# 1) The __torch_dispatch__ handles pointwise ops entirely in python
# 2) Falls back to cpp NT kernels by converting args to cpp NT and then redispatching
# 3) We can extract JT metadata from NT, so that kernels can be reused
class NestedTensor(torch.Tensor):
    buffer: torch.Tensor
    nested_sizes: torch.Tensor
    nested_strides: torch.Tensor
    offsets: torch.Tensor
    is_jagged: bool

    __torch_function__ = torch._C._disabled_torch_function_impl
    @staticmethod
    def __new__(cls, buffer, nested_sizes=None, offsets=None, *args, **kwargs):
        # TODO: what metadata should we advertise on the wrapper tensor
        r = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
            cls, buffer.size(),
            strides=buffer.stride(), storage_offset=buffer.storage_offset(),
            # TODO: clone storage aliasing
            dtype=buffer.dtype, layout=buffer.layout,
            device=buffer.device, requires_grad=kwargs.get("requires_grad", False)
        )
        # Just error instead if something requires grad?
        r.buffer = buffer.detach() if r.requires_grad else buffer
        return r
    def __init__(self, buffer, nested_sizes=None, offsets=None):
        super().__init__()
        if nested_sizes is None:
            assert offsets is not None
            # Assume that we're dealing with NT with shape (B, *, D)
            self.is_jagged = True
        else:
            self.is_jagged = False
        if offsets is not None:
            assert not (torch.is_floating_point(offsets) or torch.is_complex(offsets))
        self.nested_sizes = nested_sizes
        self._offsets = offsets
    def __repr__(self):
        # What should this print out?
        return super().__repr__(tensor_contents=f"{self.buffer}")
    @functools.cached_property
    def nested_strides(self):
        # Contiguous (row-major) strides per component: stride[i] is the product
        # of all later sizes, with the last stride being 1.
        shifted_sizes = torch.roll(self.nested_sizes, shifts=-1, dims=1)
        shifted_sizes[:, -1] = 1
        return torch.flip(torch.cumprod(torch.flip(shifted_sizes, dims=[1]), dim=1), dims=[1])
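    # Worked example (not from the original gist): for nested_sizes
    # [[2, 2], [2, 3], [2, 2]], nested_strides is [[2, 1], [3, 1], [2, 1]].
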
    @functools.cached_property
    def offsets(self):
        if self._offsets is not None:
            return self._offsets
        sizes = torch.prod(self.nested_sizes, dim=1)
        shifted_sizes = torch.roll(sizes, shifts=1, dims=0)
        shifted_sizes[0] = 0
        return torch.cumsum(shifted_sizes, dim=0)
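    # Worked example (not from the original gist): component numels [4, 6, 4]
    # give offsets [0, 4, 10], i.e. an exclusive prefix sum into the flat buffer.
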
    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = {} if kwargs is None else kwargs
        num_tensor_args = sum(isinstance(t, torch.Tensor) for t in args)
        # This path is for eager only. Once inductor is able to understand NT
        # in torch.compile, we would be telling aot autograd not to desugar and
        # thus we should not reach __torch_dispatch__. However, if we did want
        # to compile in the short term, we could tell aot dispatch to desugar
        # anyway. We may need to wrap NT ops in custom ops so that the fx
        # graph ends up with only plain tensors.
        # 0) Assert that none of the inputs are in JaggedTensor layout.
        # For now, require that the user ALWAYS extracts their JT metadata out.
        assert not any(isinstance(x, NestedTensor) and x.is_jagged for x in args + tuple(kwargs.values()))
        # 1) Do pointwise ops directly on the buffer; calling pointwise ops does
        # not materialize offsets or strides.
        if torch.Tag.pointwise in func.tags:
            if num_tensor_args == 1:
                # Assume the first argument is an NT and the op returns a single output
                return NestedTensor(func(args[0].buffer), args[0].nested_sizes)
            elif num_tensor_args == 2:
                # NYI: check that the two nested tensors have the same structure
                # or have broadcastable shapes
                pass
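                # Hedged sketch, not in the original gist: if both inputs are
                # non-jagged NestedTensors with identical nested_sizes, a binary
                # pointwise op can be applied directly to the flat buffers;
                # anything else falls through to the C++ fallback below.
                a, b = args[0], args[1]
                if (isinstance(a, NestedTensor) and isinstance(b, NestedTensor)
                        and a.nested_sizes is not None and b.nested_sizes is not None
                        and torch.equal(a.nested_sizes, b.nested_sizes)):
                    return NestedTensor(func(a.buffer, b.buffer, **kwargs), a.nested_sizes)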
        # 2) Fallback to using the C++ nested tensor kernels
        def to_cpp_nt(x):
            return py_to_cpp(x) if isinstance(x, NestedTensor) else x

        def from_cpp_nt(x):
            return cpp_to_py(x) if isinstance(x, torch.Tensor) and x.is_nested else x

        return tree_map(from_cpp_nt, func(*tree_map(to_cpp_nt, args), **tree_map(to_cpp_nt, kwargs)))
def py_to_cpp(nt: NestedTensor) -> torch.Tensor:
    return torch._nested_view_from_buffer(nt.buffer, nt.nested_sizes, nt.nested_strides, nt.offsets)

def cpp_to_py(nt: torch.Tensor) -> NestedTensor:
    return NestedTensor(nt.values(), nt._nested_tensor_size(), nt._nested_tensor_storage_offsets())
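# Note (not in the original gist): cpp_to_py does not carry over per-component
# strides; NestedTensor.nested_strides is recomputed lazily assuming each
# component is contiguous in the buffer.
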
tensors = [
    torch.tensor([[1., 2.], [1., 2.]]),
    torch.tensor([[3., 4., 5.], [3., 4., 5.]]),
    torch.tensor([[6., 7.], [6., 7.]])
]
cpp_nt = torch.nested.nested_tensor(tensors)
# We could have passed in offsets here, but this shows that we don't have to.
nt = NestedTensor(cpp_nt.values(), cpp_nt._nested_tensor_size())
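# Not in the original gist: a quick sanity check that the Python <-> C++ round
# trip via py_to_cpp/cpp_to_py preserves the per-component sizes and the total
# number of buffer elements.
rt = cpp_to_py(py_to_cpp(nt))
assert torch.equal(rt.nested_sizes, nt.nested_sizes)
assert rt.buffer.numel() == nt.buffer.numel()
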
# 1) When doing pointwise we stay in python land
out: NestedTensor = torch.sin(nt)
assert torch.allclose(nt.buffer.sin(), out.buffer)
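# Not in the original gist: the pointwise path should carry the per-component
# sizes through unchanged.
assert torch.equal(out.nested_sizes, nt.nested_sizes)
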
# 2) Utilize existing C++ NestedTensor kernels
out: NestedTensor = nt.bmm(nt.transpose(dim0=-1, dim1=-2))
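# Not in the original gist: bmm of (2, k_i) with (k_i, 2) components should
# yield (2, 2) components for every batch element.
expected_sizes = torch.stack([nt.nested_sizes[:, 0], nt.nested_sizes[:, 0]], dim=1)
assert torch.equal(out.nested_sizes, expected_sizes)
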
# 3) Call into JT kernels (B, *, D) case
tensors = [
    torch.tensor([[1., 2.], [1., 2.]]),
    torch.tensor([[3., 4.], [3., 4.], [3., 4.], [3., 4.]]),
    torch.tensor([[6., 7.], [3., 4.], [6., 7.]])
]
cpp_nt = torch.nested.nested_tensor(tensors)
# Can we improve the UX here?
nt = NestedTensor(cpp_nt.values(), None, cpp_nt._nested_tensor_storage_offsets())
# Actually find a kernel to call into?
print(nt.buffer, nt.offsets)
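# Hedged sketch, not in the original gist: the flat buffer plus the storage
# offsets above are already enough to emulate a simple jagged reduction, here
# a per-component sum over all elements of each component.
flat = nt.buffer.reshape(-1)
bounds = torch.cat([nt.offsets, torch.tensor([flat.numel()])])
component_sums = torch.stack([flat[bounds[i]:bounds[i + 1]].sum()
                              for i in range(len(nt.offsets))])
print(component_sums)  # expected: tensor([ 6., 28., 33.]) for the tensors above
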
# 4) [optional] Try to desugar into a plain-tensor-only fx graph so we can compile
#
# Not doing this now because compiling is not high-pri, but the way we'd do this
# is to hide the NT operations behind a plain tensor aten op
# def get_op(name, ns):
#     return getattr(getattr(torch.ops, ns), name).default
# nested_ns = "nested"
# lib = Library(nested_ns, "FRAGMENT")
# def nested_tensor_foo():
#     pass
# lib.define("foo(Tensor storage, Tensor offsets) -> Tensor")
# lib.impl("foo", nested_tensor_foo, "CPU")
# op = get_op("foo", nested_ns)
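# Hedged, runnable version of the sketch above (not in the original gist).
# "nested_poc" and "buffer_identity" are made-up names for illustration; the
# op just returns a copy of the buffer, so only the registration plumbing is shown.
_poc_lib = Library("nested_poc", "DEF")
_poc_lib.define("buffer_identity(Tensor buffer, Tensor offsets) -> Tensor")
def _buffer_identity(buffer, offsets):
    return buffer.clone()
_poc_lib.impl("buffer_identity", _buffer_identity, "CPU")
# e.g. torch.ops.nested_poc.buffer_identity(nt.buffer, nt.offsets)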