NestedTensor __torch_dispatch__ wrapper tensor subclass POC with automatic dispatching to custom kernel
import torch
from torch.library import Library
from typing import List
from torch.utils._pytree import tree_map
import torch.nn.functional as F
import functools
import numpy as np
# NestedTensor __torch_dispatch__ wrapper tensor subclass POC
#
# - The __torch_dispatch__ handles pointwise ops entirely in python
# - Falls back to cpp NT kernels by converting args to cpp NT and then redispatching
# - We can extract JT metadata from NT, so that kernels can be reused
# - [new] some ops like torch.mm can be done easily for jagged layout
# - [new] automatically dispatch to custom kernels
# - [new] custom kernels can support autograd
# Other improvements to do:
# - unchecked version of _nested_view_from_buffer() that just shoves the metadata in
# - introduce a jagged layout option to torch.layout and utilize that in NestedTensor
# - for all NT factory functions, we'll have to return a NestedTensor instead
UNTAGGED_POINTWISE_OPS = {
    torch.ops.aten.silu.default,
    torch.ops.aten.detach.default,
}
# A decomposition table specific to the NestedTensor subclass
nt_decomp_table = {}
# see (4) below for how this is used
def register_nt_decomp(op):
    def decorator(fn):
        nt_decomp_table[op] = fn
        def wrapper(*args, **kwargs):
            return fn(*args, **kwargs)
        return wrapper
    return decorator
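# Example (illustrative, not part of the original POC): a decomp for an existing
# aten op could be registered the same way, e.g.
#
#   @register_nt_decomp(torch.ops.aten.silu.default)
#   def silu_decomp(x):
#       return NestedTensor(F.silu(x.buffer), None, x.offsets)
#
# Section (4) below registers decomps for custom ops this way.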
class NestedTensor(torch.Tensor):
    buffer: torch.Tensor
    nested_sizes: torch.Tensor
    nested_strides: torch.Tensor
    offsets: torch.Tensor
    is_jagged: bool
    __torch_function__ = torch._C._disabled_torch_function_impl
    @staticmethod
    def __new__(cls, buffer, nested_sizes=None, offsets=None, *args, **kwargs):
        # TODO: what metadata should we advertise on the wrapper tensor
        r = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
            cls, buffer.size(),
            strides=buffer.stride(), storage_offset=buffer.storage_offset(),
            # TODO: clone storage aliasing
            dtype=buffer.dtype, layout=buffer.layout,
            device=buffer.device, requires_grad=kwargs.get("requires_grad", False)
        )
        # Just error instead if something requires grad?
        r.buffer = buffer.detach() if r.requires_grad else buffer
        return r
    def __init__(self, buffer, nested_sizes=None, offsets=None, *args, **kwargs):
        super().__init__()
        assert not isinstance(buffer, NestedTensor)
        if nested_sizes is None:
            assert offsets is not None
            assert buffer.ndim == 2
            # Assume that we're dealing with NT with shape (B, *, D)
            self._offsets = offsets
            return
        assert buffer.ndim == 1
        if offsets is not None:
            assert not (torch.is_floating_point(offsets) or torch.is_complex(offsets))
        self.nested_sizes = nested_sizes
        self._offsets = offsets
    def __repr__(self):
        # What should this print out?
        return super().__repr__(tensor_contents=f"{self.buffer}")
    @functools.cached_property
    def is_jagged(self):
        return self.buffer.ndim == 2
    @functools.cached_property
    def nested_strides(self):
        # Contiguous (row-major) strides per component: strides[i] = prod(sizes[i+1:])
        shifted_sizes = torch.roll(self.nested_sizes.flip(dims=[1]), shifts=1, dims=1)
        shifted_sizes[:, 0] = 1
        return torch.cumprod(shifted_sizes, dim=1).flip(dims=[1])
    @functools.cached_property
    def offsets(self):
        if self._offsets is not None:
            return self._offsets
        sizes = torch.prod(self.nested_sizes, dim=1)
        shifted_sizes = torch.roll(sizes, shifts=1, dims=0)
        shifted_sizes[0] = 0
        return torch.cumsum(shifted_sizes, dim=0)
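    # Worked example (illustrative): for nested_sizes = [[2, 2], [2, 3], [2, 2]],
    # the per-component numels are [4, 6, 4], so offsets = [0, 4, 10] and the
    # contiguous per-component strides are [[2, 1], [3, 1], [2, 1]].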
    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # This path is for eager only. Once inductor is able to understand NT
        # under torch.compile, we would tell aot autograd not to desugar and
        # thus we should not reach __torch_dispatch__. However, if we did want
        # to compile in the short term, we could tell aot dispatch to desugar
        # anyway. We may need to wrap NT ops in custom ops, so that the fx
        # graph ends up with only plain tensors.
        kwargs = {} if kwargs is None else kwargs
        num_tensor_args = sum(isinstance(t, torch.Tensor) for t in args)
        def is_jagged(x):
            return isinstance(x, NestedTensor) and x.is_jagged
        def is_strided(x):
            return isinstance(x, NestedTensor) and not x.is_jagged
        if any(isinstance(x, NestedTensor) for x in kwargs.values()):
            raise NotImplementedError("NestedTensor in kwargs is not implemented")
        any_jagged = any(is_jagged(x) for x in args)
        any_strided = any(is_strided(x) for x in args)
        if any_jagged and any_strided:
            # 0) Do not support mixed jagged strided layout operations
            raise NotImplementedError("Mixed jagged/strided operations are not implemented")
        # Note: Desugarable ops
        #
        # Operations on nested tensors are desugarable if they can desugar into
        # operations directly on the nested tensor's buffer.
        if any_jagged:
            # 1) desugarable operations for jagged layout only, e.g. mm
            assert not any_strided
            # TODO: we can split this out instead of doing a bunch of if-else
            if func == torch.ops.aten.mm.default:
                # (B, *, D) @ (D, K)
                assert is_jagged(args[0])
                assert isinstance(args[1], torch.Tensor)
                # Do the dimension checking here to print out a nice message?
                return NestedTensor(func(args[0].buffer, args[1]), None, args[0].offsets)
            if func == torch.ops.aten.split.Tensor:
                # TODO: in a view situation the output buffer aliases the input buffer,
                # but this view relationship is not reflected on the wrapper subclass;
                # only allow functional operations
                return NestedTensor(func(args[0].buffer, args[1]))
            # Dispatch to any registered jagged tensor kernels
            if func in nt_decomp_table:
                return nt_decomp_table[func](*args, **kwargs)
        # 2) Pointwise ops are desugarable for jagged and strided layout
        # Note that for strided format, there's no need to materialize
        # offsets or strides
        if torch.Tag.pointwise in func.tags or func in UNTAGGED_POINTWISE_OPS:
            if num_tensor_args == 1:
                # Assume first argument is NT, and returns a single output
                if is_jagged(args[0]):
                    return NestedTensor(func(args[0].buffer), None, args[0].offsets)
                else:
                    return NestedTensor(func(args[0].buffer), args[0].nested_sizes)
            elif num_tensor_args == 2:
                # NYI: Check the two nested tensors have the same structure
                # or have broadcastable shapes
                pass
        if any_jagged:
            # Users should convert to strided layout, padded layout?
            raise NotImplementedError(f"{func} not supported for nested tensor in Jagged layout")
        # 3) Fallback to using C++ nested tensor kernels for strided layout
        def to_cpp_nt(x):
            return py_to_cpp(x) if isinstance(x, NestedTensor) else x
        def from_cpp_nt(x):
            return cpp_to_py(x) if isinstance(x, torch.Tensor) and x.is_nested else x
        return tree_map(from_cpp_nt, func(*tree_map(to_cpp_nt, args), **tree_map(to_cpp_nt, kwargs)))
def py_to_cpp(nt: NestedTensor) -> torch.Tensor:
    return torch._nested_view_from_buffer(nt.buffer, nt.nested_sizes, nt.nested_strides, nt.offsets)
def cpp_to_py(nt: torch.Tensor) -> NestedTensor:
    return NestedTensor(nt.values(), nt._nested_tensor_size(), nt._nested_tensor_storage_offsets())
tensors = [
    torch.tensor([[1., 2.], [1., 2.]]),
    torch.tensor([[3., 4., 5.], [3., 4., 5.]]),
    torch.tensor([[6., 7.], [6., 7.]])
]
cpp_nt = torch.nested.nested_tensor(tensors)
# We could've passed in offsets, but just to show that we don't have to here.
nt = NestedTensor(cpp_nt.values(), cpp_nt._nested_tensor_size())
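# Round-trip sanity check (illustrative addition, not part of the original POC):
# py_to_cpp should rebuild an equivalent cpp NT from the python-side metadata
# (buffer, lazily computed strides, and offsets).
assert torch.equal(py_to_cpp(nt).values(), cpp_nt.values())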
# 1) When doing pointwise we stay in python land
out: NestedTensor = torch.sin(nt)
assert torch.allclose(nt.buffer.sin(), out.buffer)
# 2) Utilize existing C++ NestedTensor kernels
out: NestedTensor = nt.bmm(nt.transpose(dim0=-1, dim1=-2))
# 3) Call into JT kernels (B, *, D) case
tensors = [
    torch.tensor([[1., 2.], [1., 2.]]),
    torch.tensor([[3., 4.], [3., 4.], [3., 4.], [3., 4.]]),
    torch.tensor([[6., 7.], [3., 4.], [6., 7.]])
]
cpp_nt = torch.nested.nested_tensor(tensors)
# Can we improve the UX here?
nt = NestedTensor(torch.reshape(cpp_nt.values(), (9, 2)), None, cpp_nt._nested_tensor_storage_offsets())
# Do a mixed JT / plain tensor operation; these just desugar into plain torch ops
t = torch.ones(2, 2)
F.silu(torch.mm(nt, t))
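# Sanity check (illustrative addition, not part of the original POC): the jagged
# result should match doing the same ops per component, assuming the jagged
# buffer is the row-wise concatenation of the original components.
jt_out = F.silu(torch.mm(nt, t))
expected = torch.cat([F.silu(torch.mm(c, t)) for c in tensors])
assert torch.allclose(jt_out.buffer, expected)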
# 4a) Custom operator registration and automatic dispatch
#
# For the ops that cannot desugar into plain torch ops, we may need custom
# operators. Here we demonstrate with an impl that calls into numpy.
#
# The user would need to do several things here:
#
# 1) register dummy custom ops, e.g. torch.ops.nested.foo
# 2) register a decomp keyed on torch.ops.nested.foo, so that during
#    __torch_dispatch__ we check this table and do the proper {un,}wrapping
#    to call into the underlying foo_impl(t, t_offsets) kernel, which takes
#    the two JaggedTensor tensors
# 3) the registration would need to be done for forward and backward
#    separately and glued together with a custom autograd Function, which
#    would then be registered at the Autograd key.
#
# Models already do (3) with their triton kernels; we just need to
# register dummy custom kernels that represent the forward and backward,
# then call them from our decomps.
lib = Library('custom_ns', 'FRAGMENT')
# Custom operators that take and return NTs
lib.define("custom_op_fw(Tensor x) -> Tensor")
lib.define("custom_op_bw(Tensor grad_out, Tensor saved_x) -> Tensor")
# Register some dummy kernels that never really get called (see below)
def custom_op_fw(x: NestedTensor):
    raise AssertionError("internal assert: we don't expect to reach here")
def custom_op_bw(gx: torch.Tensor, x: NestedTensor):
    raise AssertionError("internal assert: we don't expect to reach here")
lib.impl("custom_op_fw", custom_op_fw, "CPU")
lib.impl("custom_op_bw", custom_op_bw, "CPU")
# In eager we would never reach the dummy kernels because we would've
# decomposed using the decompositions registered below.
# The user does their own unwrapping and wrapping; otherwise it would be hard
# for us to figure out how to rewrap unless the user provided extra information
# in the signature about which outputs are nested tensors. Even then, this way
# is the most flexible for the user since you can now return multiple NTs
# that share the same offsets tensor.
@register_nt_decomp(torch.ops.custom_ns.custom_op_fw.default)
def custom_op_fw_impl(x: NestedTensor):
    return NestedTensor(
        torch.tensor(np.sin(x.buffer.cpu().numpy())),
        None,
        x.offsets
    )
@register_nt_decomp(torch.ops.custom_ns.custom_op_bw.default)
def custom_op_bw_impl(grad_out, saved_x):
    # The gradient might be jagged
    # assert not isinstance(NestedTensor, grad_out)
    return NestedTensor(
        grad_out.buffer * torch.tensor(np.cos(saved_x.buffer.cpu().numpy())),
        None,
        saved_x.offsets
    )
# Tie the forward and backward together in a custom autograd Function.
# Saving tensors for backward, etc. is handled here as usual.
class CustomOp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        with torch._C._AutoDispatchBelowAutograd():
            return torch.ops.custom_ns.custom_op_fw(x)
    @staticmethod
    def backward(ctx, gx):
        x, = ctx.saved_tensors
        return torch.ops.custom_ns.custom_op_bw(gx, x)
def custom_op_fw_autograd(x):
    return CustomOp.apply(x)
lib.impl("custom_op_fw", custom_op_fw_autograd, "Autograd")
nt = NestedTensor(
    torch.reshape(cpp_nt.values(), (9, 2)),
    None,
    cpp_nt._nested_tensor_storage_offsets(),
    requires_grad=True
)
out = torch.ops.custom_ns.custom_op_fw(nt)
assert torch.allclose(nt.buffer.sin(), out.buffer)
# 4b) Autograd support
# This extra helper converts a jagged-format NT to its buffer in a
# differentiable way, because I didn't feel like writing the sum op.
# What does the sum op even mean here?
# This probably supports both Jagged layout and strided layout.
# We could wrap this in a custom op if we wanted to.
class ViewBufferFromJagged(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: NestedTensor):
        # Jagged -> plain
        assert x.is_jagged
        ctx.save_for_backward(x.offsets)
        # This is technically aliased to input, but autograd function doesn't know!
        return x.buffer
    @staticmethod
    def backward(ctx, gO):
        offsets, = ctx.saved_tensors
        # plain -> jagged
        return NestedTensor(gO, None, offsets)
buffer = ViewBufferFromJagged.apply(out)
buffer.sum().backward()
assert torch.allclose(nt.buffer.cos(), nt.grad.buffer)
assert torch.allclose(nt.offsets, nt.grad.offsets)
# 5) [optional] Support torch.compile
#
# Not doing this now because compiling is not high-pri in the short term,
# at least until inductor is able to understand NT.
#
# We need some more things here to support torch.compile:
# - meta functions for the dummy operators
# - hide more operations behind custom ops, so that we can desugar into a
#   plain-tensor-only fx graph
# def custom_op_fw_meta(x):
#     # What do we do here?
#     return x.sin()
# lib.impl("custom_op_fw", custom_op_fw_meta, "Meta")
# x = torch.randn([], requires_grad=True)
# y = torch.compile(torch.ops.custom_ns.custom_op_fw, fullgraph=True, backend='aot_eager')(x)
# gx, = torch.autograd.grad(y, x)
# assert torch.allclose(gx, x.cos())