Last active
July 19, 2023 15:48
-
-
Save soulitzer/e60116e2820d505f1a99a21a692b7452 to your computer and use it in GitHub Desktop.
NestedTensor __torch_dispatch__ wrapper tensor subclass POC with automatic dispatching to custom kernel
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from torch.library import Library | |
from typing import List | |
from torch.utils._pytree import tree_map | |
import torch.nn.functional as F | |
import functools | |
import numpy as np | |
# NestedTensor __torch_dispatch__ wrapper tensor subclass POC | |
# | |
# - The __torch_dispatch__ handles pointwise ops entirely in python | |
# - Falls back to cpp NT kernels by converting args to cpp NT and then redispatching
# - We can extract JT metadata from NT, so that kernels can be reused | |
# - [new] some ops like torch.mm can be done easily for jagged layout | |
# - [new] automatically dispatch to custom kernels | |
# - [new] custom kernels can support autograd | |
# Other improvements to do: | |
# - unchecked version of _nested_view_from_buffer() that just shoves the metadata in | |
# - introduce a jagged layout option to torch.layout and utilize that in NestedTensor | |
# - for all NT factory functions, we'll have to return a NestedTensor instead | |
# Op overloads that NestedTensor.__torch_dispatch__ treats as pointwise even
# though they are looked up here rather than via ``torch.Tag.pointwise``.
UNTAGGED_POINTWISE_OPS = {
    torch.ops.aten.detach.default,
    torch.ops.aten.silu.default,
}
# A decomposition table specific to the NestedTensor subclass.
# Keys are op overloads, values are the Python callables registered for them.
nt_decomp_table = {}


# see (4) below for how this is used
def register_nt_decomp(op):
    """Return a decorator that registers a function in ``nt_decomp_table`` under ``op``.

    The undecorated function is what gets stored in the table. The decorator
    itself returns a thin wrapper that forwards every call to that function;
    ``functools.wraps`` keeps the wrapper's ``__name__``/``__doc__`` in sync
    with the function it wraps so the decorated name stays debuggable.

    Args:
        op: the key to register under (used as a dict key, so it must be hashable).

    Returns:
        A decorator taking the implementation function.
    """

    def decorator(fn):
        nt_decomp_table[op] = fn

        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            return fn(*args, **kwargs)

        return wrapper

    return decorator
class NestedTensor(torch.Tensor):
    """Wrapper tensor subclass POC for nested tensors, driven by ``__torch_dispatch__``.

    The wrapper advertises the size/stride/dtype/device of ``buffer`` and keeps
    the nested metadata as plain Python attributes. The layout is determined by
    the rank of ``buffer``:

    - strided layout: ``buffer`` is 1-D and ``nested_sizes`` is required;
      ``offsets`` is optional and derived from ``nested_sizes`` when missing.
    - jagged layout: ``buffer`` is 2-D, ``nested_sizes`` is ``None`` and
      ``offsets`` is required (assumed to describe an NT of shape (B, *, D)).

    Pass ``requires_grad=True`` as a keyword to mark the wrapper as requiring
    grad; in that case the stored buffer is a detached alias of the input.
    """

    buffer: torch.Tensor
    nested_sizes: torch.Tensor
    nested_strides: torch.Tensor
    offsets: torch.Tensor
    is_jagged: bool

    __torch_function__ = torch._C._disabled_torch_function_impl

    @staticmethod
    def __new__(cls, buffer, nested_sizes=None, offsets=None, *args, **kwargs):
        # TODO: what metadata should we advertise on the wrapper tensor
        r = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
            cls,
            buffer.size(),
            strides=buffer.stride(),
            storage_offset=buffer.storage_offset(),
            # TODO: clone storage aliasing
            dtype=buffer.dtype,
            layout=buffer.layout,
            device=buffer.device,
            requires_grad=kwargs.get("requires_grad", False),
        )
        # Just error instead if something requires grad?
        r.buffer = buffer.detach() if r.requires_grad else buffer
        return r

    def __init__(self, buffer, nested_sizes=None, offsets=None, *args, **kwargs):
        super().__init__()
        assert not isinstance(buffer, NestedTensor)
        if nested_sizes is None:
            # Jagged layout: only offsets describe the structure.
            assert offsets is not None
            assert buffer.ndim == 2
            # Assume that we're dealing with NT with shape (B, *, D)
            self._offsets = offsets
            return
        # Strided layout: flat buffer plus one row of sizes per component.
        assert buffer.ndim == 1
        if offsets is not None:
            assert not (torch.is_floating_point(offsets) or torch.is_complex(offsets))
        self.nested_sizes = nested_sizes
        self._offsets = offsets

    def __repr__(self):
        # What should this print out?
        # Format the buffer: ``self.storage`` would resolve to the inherited
        # ``Tensor.storage`` bound method, not to the data held by this wrapper.
        return super().__repr__(tensor_contents=f"{self.buffer}")

    @functools.cached_property
    def is_jagged(self):
        """True when the underlying buffer is 2-D (jagged layout)."""
        return self.buffer.ndim == 2

    @functools.cached_property
    def nested_strides(self):
        """Per-component contiguous (row-major) strides derived from ``nested_sizes``.

        For a component of sizes ``(s0, ..., sk)`` the stride of dim ``i`` is the
        product of the sizes of all later dims, e.g. ``(2, 3) -> (3, 1)``.
        Only meaningful for strided layout, where ``nested_sizes`` is set.
        """
        # Work on the reversed sizes so an exclusive cumulative product runs
        # from the innermost dim outwards, then flip back into dim order.
        reversed_sizes = torch.flip(self.nested_sizes, dims=[1])
        shifted_sizes = torch.roll(reversed_sizes, shifts=1, dims=1)
        shifted_sizes[:, 0] = 1
        return torch.flip(torch.cumprod(shifted_sizes, dim=1), dims=[1])

    @functools.cached_property
    def offsets(self):
        """Start offset of every component inside ``buffer``.

        Returns the offsets given at construction when there are any; otherwise
        the exclusive prefix sum of the per-component element counts.
        """
        if self._offsets is not None:
            return self._offsets
        sizes = torch.prod(self.nested_sizes, dim=1)
        # Exclusive prefix sum; unlike roll + index assignment this also works
        # when there are zero components.
        return torch.cumsum(sizes, dim=0) - sizes

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        """Route ``func`` to a buffer-level implementation or to the C++ NT kernels.

        Order of handling:
          0) mixed jagged/strided arguments are rejected
          1) jagged-only ops (mm, split) and kernels registered in ``nt_decomp_table``
          2) unary pointwise ops, applied directly to the buffer
          3) strided layout falls back to the C++ nested tensor kernels

        Raises:
            NotImplementedError: for NestedTensors passed as kwargs, for mixed
                layouts, and for jagged ops that none of the above handles.
        """
        # This path is for eager only. Once inductor is able to understand NT
        # torch.compile, we would be telling aot autograd not to desugar and
        # thus, we should not reach __torch_dispatch__. However, if we did want
        # to compile for the short term, we could tell aot dispatch to desugar
        # anyway. We may need to wrap NT ops in custom ops, so that the fx
        # graph ends up with only plain tensors.
        kwargs = {} if kwargs is None else kwargs
        num_tensor_args = sum(isinstance(t, torch.Tensor) for t in args)

        def is_jagged(x):
            return isinstance(x, NestedTensor) and x.is_jagged

        def is_strided(x):
            return isinstance(x, NestedTensor) and not x.is_jagged

        if any(isinstance(x, NestedTensor) for x in kwargs.values()):
            raise NotImplementedError("NestedTensor in kwargs is not implemented")

        any_jagged = any(is_jagged(x) for x in args)
        any_strided = any(is_strided(x) for x in args)
        if any_jagged and any_strided:
            # 0) Do not support mixed jagged strided layout operations
            raise NotImplementedError("Mixed jagged strided operations is not implemented")

        # Note: Desugarable ops
        #
        # Operations for nested tensor are desugarable if they can desugar into
        # operations directly on the nested tensor's buffer.
        if any_jagged:
            # 1) desugarable operations for jagged layout only, e.g. mm
            assert not any_strided
            # TODO: we can split this out instead of doing a bunch of if-else
            if func == torch.ops.aten.mm.default:
                # (B, *, D) @ (D, K)
                assert is_jagged(args[0])
                assert isinstance(args[1], torch.Tensor)
                # Do the dimension checking here to print out a nice message?
                return NestedTensor(func(args[0].buffer, args[1]), None, args[0].offsets)
            if func == torch.ops.aten.split.Tensor:
                # TODO: in a view situation buffer aliases buffer
                # but this view relationship is not reflected on the wrapper subclass
                # only allow functional operations
                assert is_jagged(args[0])
                nt = args[0]
                split_dim = args[2] if len(args) > 2 else kwargs.get("dim", 0)
                # Splitting the rows of the buffer would invalidate the offsets,
                # so only the trailing (feature) dim can be split here.
                if split_dim not in (-1, nt.buffer.ndim - 1):
                    raise NotImplementedError(
                        "split on a jagged NestedTensor is only implemented along the last dim"
                    )
                # split returns several tensors: wrap each piece separately,
                # all sharing the offsets of the input.
                pieces = func(nt.buffer, *args[1:], **kwargs)
                return [NestedTensor(piece, None, nt.offsets) for piece in pieces]
            # Dispatch to any registered jagged tensor kernels
            if func in nt_decomp_table:
                return nt_decomp_table[func](*args, **kwargs)

        # 2) Pointwise ops are desugarable for jagged and strided layout
        #    Note that for strided format, there's no need to materialize
        #    offsets or strides
        if torch.Tag.pointwise in func.tags or func in UNTAGGED_POINTWISE_OPS:
            if num_tensor_args == 1 and isinstance(args[0], NestedTensor):
                # Assume first argument is NT, and returns a single output.
                # Remaining (non-tensor) args and kwargs are forwarded untouched.
                nt = args[0]
                out_buffer = func(nt.buffer, *args[1:], **kwargs)
                if nt.is_jagged:
                    return NestedTensor(out_buffer, None, nt.offsets)
                # The output buffer is laid out like the input buffer, so any
                # explicit offsets stay valid (None means "derive from sizes").
                return NestedTensor(out_buffer, nt.nested_sizes, nt._offsets)
            elif num_tensor_args == 2:
                # NYI: Check the two nested tensors have the same structure
                #      or have broadcastable shapes
                pass

        if any_jagged:
            # Users should convert to strided layout, padded layout?
            raise NotImplementedError(f"{func} not supported for nested tensor in Jagged layout")

        # 3) Fallback to using C++ nested tensor kernels for strided layout
        def to_cpp_nt(x):
            return py_to_cpp(x) if isinstance(x, NestedTensor) else x

        def from_cpp_nt(x):
            return cpp_to_py(x) if isinstance(x, torch.Tensor) and x.is_nested else x

        return tree_map(from_cpp_nt, func(*tree_map(to_cpp_nt, args), **tree_map(to_cpp_nt, kwargs)))
def py_to_cpp(nt: NestedTensor) -> torch.Tensor:
    """Build a C++ nested tensor view from the buffer and metadata of ``nt``."""
    buffer = nt.buffer
    sizes = nt.nested_sizes
    strides = nt.nested_strides
    offsets = nt.offsets
    return torch._nested_view_from_buffer(buffer, sizes, strides, offsets)
def cpp_to_py(nt: torch.Tensor) -> NestedTensor:
    """Wrap a C++ nested tensor into the Python ``NestedTensor`` subclass.

    The values, nested sizes and storage offsets are read from ``nt`` and
    passed to the ``NestedTensor`` constructor in that order.
    """
    values = nt.values()
    sizes = nt._nested_tensor_size()
    storage_offsets = nt._nested_tensor_storage_offsets()
    return NestedTensor(values, sizes, storage_offsets)
# Strided-layout demo input: every component has two identical rows, but the
# row width differs between components.
tensors = [
    torch.tensor([row, row])
    for row in ([1., 2.], [3., 4., 5.], [6., 7.])
]
cpp_nt = torch.nested.nested_tensor(tensors)
# We could've passed in offsets, but just to show that we don't have to here.
nt = NestedTensor(cpp_nt.values(), cpp_nt._nested_tensor_size())

# 1) When doing pointwise we stay in python land
out: NestedTensor = nt.sin()
assert torch.allclose(torch.sin(nt.buffer), out.buffer)

# 2) Utilize existing C++ NestedTensor kernels
out: NestedTensor = nt.bmm(nt.transpose(dim0=-1, dim1=-2))

# 3) Call into JT kernels (B, *, D) case
# Here the components differ in their number of rows and share a width of 2.
tensors = [
    torch.tensor([[1., 2.]] * 2),
    torch.tensor([[3., 4.]] * 4),
    torch.tensor([[6., 7.], [3., 4.], [6., 7.]]),
]
cpp_nt = torch.nested.nested_tensor(tensors)
# Can we improve the UX here?
nt = NestedTensor(
    cpp_nt.values().reshape((9, 2)),
    None,
    cpp_nt._nested_tensor_storage_offsets(),
)
# Do a mixed JT, plain tensor operation, these just desugar into plain torch ops
t = torch.ones(2, 2)
F.silu(torch.mm(nt, t))
# 4a) Custom operator registration and automatic dispatch | |
# | |
# For the ops that cannot desugar into plain torch ops, we may need custom
# operators. Here we demonstrate with an impl that calls into numpy
# | |
# The user would need to do several things here: | |
# | |
# 1) register dummy custom ops, e.g. torch.ops.nested.foo | |
# 2) register decomp that would decompose with key torch.ops.nested.foo | |
# so that during __torch_dispatch__ we would check this table and | |
# do the proper {un,}wrapping to call into the underlying | |
# foo_impl(t, t_offsets) kernel which takes the two JaggedTensor tensors | |
# 3) the registration would need to be done for forward and backward | |
# separately and glued together with a custom autograd Function which | |
# would then be registered at the autograd key. | |
# | |
# Models already do (3) with their triton kernels, we just need to | |
# register dummy custom kernels that represent the forward and backward | |
# then we call them from our decomps. | |
lib = Library('custom_ns', 'FRAGMENT') | |
# Custom operators that take and return nts | |
lib.define("custom_op_fw(Tensor x) -> Tensor") | |
lib.define("custom_op_bw(Tensor grad_out, Tensor saved_x) -> Tensor") | |
# Register some dummy kernels that never really get called (see below)
def custom_op_fw(x: NestedTensor):
    """Placeholder kernel for ``custom_op_fw``; it unconditionally raises."""
    raise AssertionError("internal assert: we don't expect to reach here")


def custom_op_bw(gx: torch.Tensor, x: NestedTensor):
    """Placeholder kernel for ``custom_op_bw``; it unconditionally raises."""
    raise AssertionError("internal assert: we don't expect to reach here")


# Both placeholders are registered under the "CPU" dispatch key.
lib.impl("custom_op_fw", custom_op_fw, "CPU")
lib.impl("custom_op_bw", custom_op_bw, "CPU")
# In eager we would never reach the dummy kernels because we would've | |
# decompose using the decompositions registered below | |
# User does their own unwrapping and wrapping; otherwise it would be hard for | |
# us to figure out how to rewrap unless the user provided extra information in | |
# the signature about which outputs are nested tensor. Even then this way | |
# is the most flexible for the user since you can now return multiple NTs | |
# that share the same offsets tensor. | |
@register_nt_decomp(torch.ops.custom_ns.custom_op_fw.default)
def custom_op_fw_impl(x: NestedTensor):
    """Forward decomposition: sine of the buffer, computed through numpy.

    The result is a new jagged ``NestedTensor`` that reuses the offsets of ``x``.
    """
    # Unwrap: hand the raw buffer to numpy.
    sin_values = np.sin(x.buffer.cpu().numpy())
    # Rewrap: same offsets as the input.
    return NestedTensor(torch.tensor(sin_values), None, x.offsets)
@register_nt_decomp(torch.ops.custom_ns.custom_op_bw.default)
def custom_op_bw_impl(grad_out, saved_x):
    """Backward decomposition: ``grad_out * cos(saved_x)``, cosine computed through numpy.

    Args:
        grad_out: incoming gradient; either a ``NestedTensor`` or a plain tensor
            that multiplies elementwise with the buffer of ``saved_x``.
        saved_x: the ``NestedTensor`` saved by the forward pass.

    Returns:
        A jagged ``NestedTensor`` that reuses the offsets of ``saved_x``.
    """
    # The gradient might be jagged; unwrap it only in that case so a plain
    # tensor gradient is accepted as well.
    grad_buffer = grad_out.buffer if isinstance(grad_out, NestedTensor) else grad_out
    cos_values = torch.tensor(np.cos(saved_x.buffer.cpu().numpy()))
    return NestedTensor(grad_buffer * cos_values, None, saved_x.offsets)
# Tie the forward and backward together in a custom autograd Function | |
# Saving tensors for backward, etc should be | |
class CustomOp(torch.autograd.Function):
    """Autograd Function pairing ``custom_op_fw`` with ``custom_op_bw``."""

    @staticmethod
    def forward(ctx, x):
        """Save ``x`` for backward and run the forward op below the autograd key."""
        ctx.save_for_backward(x)
        with torch._C._AutoDispatchBelowAutograd():
            result = torch.ops.custom_ns.custom_op_fw(x)
        return result

    @staticmethod
    def backward(ctx, gx):
        """Feed the incoming gradient and the saved input to the backward op."""
        (x,) = ctx.saved_tensors
        return torch.ops.custom_ns.custom_op_bw(gx, x)
def custom_op_fw_autograd(x):
    """Autograd-key kernel for ``custom_op_fw``: applies ``CustomOp`` to ``x``."""
    result = CustomOp.apply(x)
    return result


# Registered under the "Autograd" dispatch key.
lib.impl("custom_op_fw", custom_op_fw_autograd, "Autograd")
# Same jagged input as in (3), this time requiring grad.
nt = NestedTensor(
    cpp_nt.values().reshape((9, 2)),
    None,
    cpp_nt._nested_tensor_storage_offsets(),
    requires_grad=True,
)
out = torch.ops.custom_ns.custom_op_fw(nt)
# The custom op is expected to match an elementwise sine of the buffer.
assert torch.allclose(torch.sin(nt.buffer), out.buffer)
# 4b) Autograd support | |
# This extra helper converts a Jagged format NT to its buffer in
# a differentiable way, because I didn't feel like writing the sum op
# What does the sum op even mean here? | |
# This probably supports both Jagged layout and strided layout | |
# We can wrap this in a custom op if we wanted to | |
class ViewBufferFromJagged(torch.autograd.Function):
    """Differentiable conversion from a jagged ``NestedTensor`` to its buffer."""

    @staticmethod
    def forward(ctx, x: NestedTensor):
        """Jagged -> plain: return the buffer of ``x`` and remember its offsets."""
        assert x.is_jagged
        ctx.save_for_backward(x.offsets)
        # This is technically aliased to input, but autograd function doesn't know!
        return x.buffer

    @staticmethod
    def backward(ctx, gO):
        """Plain -> jagged: rewrap the incoming gradient with the saved offsets."""
        (saved_offsets,) = ctx.saved_tensors
        return NestedTensor(gO, None, saved_offsets)
# Reduce the custom op output to a scalar through its buffer and backprop.
buffer = ViewBufferFromJagged.apply(out)
torch.sum(buffer).backward()
# The gradient should be a NestedTensor holding cos(buffer), with the same offsets as nt.
assert torch.allclose(torch.cos(nt.buffer), nt.grad.buffer)
assert torch.allclose(nt.offsets, nt.grad.offsets)
# 5) [optional] Support torch compile | |
# | |
# Not doing this now because compiling is not high-pri in the short term, once | |
# inductor is able to understand | |
# | |
# We need some more things here to support torch.compile | |
# - meta functions for the dummy operators | |
# - hide more operations behind custom ops, so that we can desugar into a | |
# plain-tensor-only fx graph | |
# def custom_op_fw_meta(x): | |
# # What do we do here? | |
# return x.sin() | |
# lib.impl("custom_op_fw", custom_op_fw_meta, "Meta") | |
# x = torch.randn([], requires_grad=True) | |
# y = torch.compile(torch.ops.custom_ns.custom_op, fullgraph=True, backend='aot_eager')(x) | |
# gx, = torch.autograd.grad(y, x) | |
# assert torch.allclose(gx, x.cos()) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment