NestedTensor python torch dispatch wrapper subclass
import torch
from torch.library import Library
from torch.utils._pytree import tree_map
import functools
# NestedTensor __torch_dispatch__ wrapper tensor subclass POC
#
# 1) __torch_dispatch__ handles pointwise ops entirely in python
# 2) Falls back to cpp NT kernels by converting args to cpp NT and redispatching
# 3) We can extract JT metadata from NT, so that kernels can be reused
class NestedTensor(torch.Tensor):
    buffer: torch.Tensor
    nested_sizes: torch.Tensor
    nested_strides: torch.Tensor
    offsets: torch.Tensor
    is_jagged: bool

    __torch_function__ = torch._C._disabled_torch_function_impl
    @staticmethod
    def __new__(cls, buffer, nested_sizes=None, offsets=None, *args, **kwargs):
        # TODO: what metadata should we advertise on the wrapper tensor
        r = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
            cls, buffer.size(),
            strides=buffer.stride(), storage_offset=buffer.storage_offset(),
            # TODO: clone storage aliasing
            dtype=buffer.dtype, layout=buffer.layout,
            device=buffer.device, requires_grad=kwargs.get("requires_grad", False)
        )
        # Just error instead if something requires grad?
        r.buffer = buffer.detach() if r.requires_grad else buffer
        return r
    def __init__(self, buffer, nested_sizes=None, offsets=None):
        super().__init__()
        if nested_sizes is None:
            assert offsets is not None
            # Assume that we're dealing with NT with shape (B, *, D)
            self.is_jagged = True
        else:
            self.is_jagged = False
        if offsets is not None:
            assert not (torch.is_floating_point(offsets) or torch.is_complex(offsets))
        self.nested_sizes = nested_sizes
        self._offsets = offsets
    def __repr__(self):
        # What should this print out?
        return super().__repr__(tensor_contents=f"{self.buffer}")
    @functools.cached_property
    def nested_strides(self):
        # Contiguous (row-major) strides: stride[i] = prod(sizes[i+1:]), so
        # the innermost stride is 1. Drop the leading size, append 1, and
        # take a reversed cumulative product.
        shifted_sizes = torch.roll(self.nested_sizes, shifts=-1, dims=1)
        shifted_sizes[:, -1] = 1
        return torch.flip(torch.cumprod(torch.flip(shifted_sizes, dims=[1]), dim=1), dims=[1])
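    # Worked example (assuming the row-major layout above): a component of
    # size (2, 3) gets strides (3, 1), and a (2, 3, 4) component gets (12, 4, 1).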
    @functools.cached_property
    def offsets(self):
        if self._offsets is not None:
            return self._offsets
        sizes = torch.prod(self.nested_sizes, dim=1)
        shifted_sizes = torch.roll(sizes, shifts=1, dims=0)
        shifted_sizes[0] = 0
        return torch.cumsum(shifted_sizes, dim=0)
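    # Worked example: nested_sizes [[2, 2], [2, 3], [2, 2]] gives
    # per-component numels [4, 6, 4], hence offsets [0, 4, 10].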
    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = {} if kwargs is None else kwargs
        num_tensor_args = sum(isinstance(t, torch.Tensor) for t in args)

        # This path is for eager only. Once inductor is able to understand NT
        # in torch.compile, we would tell aot autograd not to desugar, and
        # thus we should not reach __torch_dispatch__. However, if we did want
        # to compile in the short term, we could tell aot dispatch to desugar
        # anyway. We may need to wrap NT ops in custom ops, so that the fx
        # graph ends up with only plain tensors.

        # 0) Assert none of the inputs are in JaggedTensor layout
        #    For now, require that the user ALWAYS extract their JT metadata out
        assert not any(isinstance(x, NestedTensor) and x.is_jagged for x in args + tuple(kwargs.values()))

        # 1) Do pointwise directly on the buffer; calling pointwise ops does
        #    not materialize offsets or strides
        if torch.Tag.pointwise in func.tags:
            if num_tensor_args == 1:
                # Assume the first argument is NT and a single output is returned
                return NestedTensor(func(args[0].buffer), args[0].nested_sizes)
            elif num_tensor_args == 2:
                # NYI: Check the two nested tensors have the same structure
                # or have broadcastable shapes
                pass
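                # Hedged sketch of what the NYI same-structure fast path might
                # look like (not run; for now we fall through to the C++
                # fallback below):
                #   if torch.equal(args[0].nested_sizes, args[1].nested_sizes):
                #       return NestedTensor(func(args[0].buffer, args[1].buffer),
                #                           args[0].nested_sizes)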
        # 2) Fallback to using C++ nested tensor kernels
        def to_cpp_nt(x):
            return py_to_cpp(x) if isinstance(x, NestedTensor) else x

        def from_cpp_nt(x):
            return cpp_to_py(x) if isinstance(x, torch.Tensor) and x.is_nested else x

        return tree_map(from_cpp_nt, func(*tree_map(to_cpp_nt, args), **tree_map(to_cpp_nt, kwargs)))
def py_to_cpp(nt: NestedTensor) -> torch.Tensor:
    return torch._nested_view_from_buffer(nt.buffer, nt.nested_sizes, nt.nested_strides, nt.offsets)

def cpp_to_py(nt: torch.Tensor) -> NestedTensor:
    return NestedTensor(nt.values(), nt._nested_tensor_size(), nt._nested_tensor_storage_offsets())
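# Hedged sanity check for the two helpers above: a py -> cpp -> py roundtrip
# should preserve the buffer and nested sizes (assumes a contiguous,
# non-jagged NT). This helper is ours, not part of the POC's API.
def check_roundtrip(nt: NestedTensor) -> None:
    rt = cpp_to_py(py_to_cpp(nt))
    assert torch.equal(rt.buffer, nt.buffer)
    assert torch.equal(rt.nested_sizes, nt.nested_sizes)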
tensors = [
    torch.tensor([[1., 2.], [1., 2.]]),
    torch.tensor([[3., 4., 5.], [3., 4., 5.]]),
    torch.tensor([[6., 7.], [6., 7.]])
]
cpp_nt = torch.nested.nested_tensor(tensors)

# We could've passed in offsets, but just to show that we don't have to here.
nt = NestedTensor(cpp_nt.values(), cpp_nt._nested_tensor_size())
# 1) When doing pointwise we stay in python land
out: NestedTensor = torch.sin(nt)
assert torch.allclose(nt.buffer.sin(), out.buffer)
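# Hedged: exercise the conversion roundtrip helper defined above on the same NT.
check_roundtrip(nt)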
# 2) Utilize existing C++ NestedTensor kernels
out: NestedTensor = nt.bmm(nt.transpose(dim0=-1, dim1=-2))
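# Hedged check: since the fallback just redispatches to the C++ kernels, the
# result should match computing the same bmm directly on the C++ NT.
expected = cpp_nt.bmm(cpp_nt.transpose(-1, -2))
assert torch.allclose(out.buffer, expected.values())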
# 3) Call into JT kernels (B, *, D) case
tensors = [
    torch.tensor([[1., 2.], [1., 2.]]),
    torch.tensor([[3., 4.], [3., 4.], [3., 4.], [3., 4.]]),
    torch.tensor([[6., 7.], [3., 4.], [6., 7.]])
]
cpp_nt = torch.nested.nested_tensor(tensors)

# Can we improve the UX here?
nt = NestedTensor(cpp_nt.values(), None, cpp_nt._nested_tensor_storage_offsets())

# Actually find a kernel to call into?
print(nt.buffer, nt.offsets)
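# Hedged sketch of what a hand-written jagged kernel consuming (buffer,
# offsets) directly might look like: a per-component sum over the ragged
# dimension for the (B, *, D) case. `jagged_sum` is our placeholder, not an
# existing torch API, and it assumes a flat 1D buffer with element offsets.
def jagged_sum(buffer: torch.Tensor, offsets: torch.Tensor, last_dim: int) -> torch.Tensor:
    bounds = torch.cat([offsets, torch.tensor([buffer.numel()], dtype=offsets.dtype)])
    outs = []
    for start, end in zip(bounds[:-1].tolist(), bounds[1:].tolist()):
        # Each component occupies buffer[start:end], viewed as (*, last_dim)
        outs.append(buffer[start:end].view(-1, last_dim).sum(dim=0))
    return torch.stack(outs)

print(jagged_sum(nt.buffer, nt.offsets, last_dim=2))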
# 4) [optional] Try to desugar into a plain-tensor-only fx graph so we can compile
#
# Not doing this now because compiling is not high-pri, but the way we'd do this
# is to hide the NT operations behind a plain tensor aten op
# def get_op(name, ns):
#     return getattr(getattr(torch.ops, ns), name).default
# nested_ns = "nested"
# lib = Library(nested_ns, "FRAGMENT")
# def nested_tensor_foo():
#     pass
# lib.define("foo(Tensor storage, Tensor offsets) -> Tensor")
# lib.impl("foo", nested_tensor_foo, "CPU")
# op = get_op("foo", nested_ns)
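# Hedged continuation of the commented sketch above (kept commented, like the
# original): the impl would take only plain tensors, so a traced fx graph
# stays plain-tensor-only. The body below is illustrative, not the POC's design.
# def nested_tensor_foo(storage, offsets):
#     # e.g. a pointwise op expressed directly on the flat buffer
#     return torch.sin(storage)
# out = op(nt.buffer, nt.offsets)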