@andylolu2
Created April 24, 2024 21:31
Better indexing with PyTorch that doesn't work
import inspect
from typing import Set
from functools import partial

import torch
from torch import Tensor
class ConstraintTrackingTensor(Tensor):
    """Scalar index tensor that records the sizes of the axes it indexes into."""

    _constraints: Set[int]

    @staticmethod
    def add_constraint(tensor, size):
        if isinstance(tensor, ConstraintTrackingTensor):
            if hasattr(tensor, "_constraints"):
                tensor._constraints.add(size)
            else:
                tensor._constraints = {size}
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        args_l = list(args)
        if func.__name__ == "__getitem__":
            # `tensor[index]`: record the size of whichever axis a tracking
            # index is used against as a constraint on that index.
            if isinstance(args_l[1], ConstraintTrackingTensor):
                ConstraintTrackingTensor.add_constraint(args_l[1], args_l[0].shape[0])
            elif isinstance(args_l[1], tuple) and any(
                isinstance(i, ConstraintTrackingTensor) for i in args_l[1]
            ):
                for size, index in zip(args_l[0].shape, args_l[1]):
                    ConstraintTrackingTensor.add_constraint(index, size)

            # Do the actual indexing on a plain Tensor and return a plain Tensor.
            if isinstance(args_l[0], ConstraintTrackingTensor):
                args_l[0] = torch.tensor(args_l[0])
            return torch.tensor(
                super().__torch_function__(func, types, tuple(args_l), kwargs)
            )

        if kwargs is None:
            kwargs = {}
        return super().__torch_function__(func, types, tuple(args_l), kwargs)
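# Illustration of the tracking (a sketch; the names `a` and `i` below are just
# example values, mirroring what infer_output_shape does internally):
#   i = ConstraintTrackingTensor(torch.tensor(0))
#   a = torch.randn(5, 10)
#   _ = a[i]                 # __getitem__ dispatches through __torch_function__
#   print(i._constraints)    # {5}  -- i must range over a.shape[0]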
def infer_output_shape(f):
    """Call f on scalar tracking tensors and read off the recorded size constraints."""
    n_args = len(inspect.signature(f).parameters)
    dummy_indices = [ConstraintTrackingTensor(torch.tensor(0)) for _ in range(n_args)]
    out = f(*dummy_indices)
    assert out.ndim == 0  # f must map scalar indices to a scalar

    # Each index must have been constrained to exactly one axis size.
    constraints = [getattr(idx, "_constraints", set()) for idx in dummy_indices]
    assert all(len(constraint) == 1 for constraint in constraints)
    return tuple(next(iter(constraint)) for constraint in constraints)
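# E.g. for the f defined at the bottom, f(i, j) = dot(x[i], y[j]) with x and y
# of shape (5, 10): each dummy index picks up the single constraint {5}, so the
# inferred output shape is (5, 5).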
def ein_arr(f):
    output_shape = infer_output_shape(f)
    indices = []
    for i, size in enumerate(output_shape):
        # (size,) -> (1, 1, ..., size, ..., 1, 1), then broadcast to the full output shape
        index = torch.arange(size)
        index = index.view(
            *((1,) * i), size, *((1,) * (len(output_shape) - 1 - i))
        ).broadcast_to(*output_shape)
        indices.append(index)

    # "Tensor up" the function: one vmap per output dimension.
    for _ in output_shape:
        f = torch.vmap(f)
    return f(*indices)
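# Shape bookkeeping for output_shape == (5, 5) (illustrative):
#   indices[0] = arange(5).view(5, 1).broadcast_to(5, 5)   # row index i
#   indices[1] = arange(5).view(1, 5).broadcast_to(5, 5)   # column index j
# f is then wrapped in vmap twice and applied elementwise over these grids.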
x = torch.randn(5, 10)
y = torch.randn(5, 10)


def f(i, j):
    return torch.dot(x[i], y[j])


print(infer_output_shape(f))
# (5, 5)
print(ein_arr(f))
# RuntimeError: vmap: It looks like you're calling .item() on a Tensor
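# For reference, the array ein_arr(f) is meant to build satisfies
# out[i, j] == torch.dot(x[i], y[j]); for this example that is just x @ y.T,
# which can be checked directly in plain PyTorch:
#   expected = x @ y.T
#   print(expected.shape)  # torch.Size([5, 5])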